1 模型代码来源

我使用的模型是B站博主@霹雳吧啦Wz代码,他写模型的读取方法时考虑的是pascal voc这个数据集,因此如果我们如果使用其他数据集,难免会遇见一些问题,这里记录了我是如何解决的。
ps: 如果使用ultralytics版本的yolov8,他们的代码会有忽略错误标框的能力。

2 对Wider Face数据集的训练处理

2.1 转换为voc格式

参考这篇的做法,一套流程下来,voc格式,coco格式和yolo格式都有了。

2.2 错误标框导致训练不能进行的问题

我一开始时用他仓库目录下的retinanet这个模型,代码有一定的容错能力,在碰见宽度/长度为0的框时会raise error,具体的代码在retinaNet/network_files/retinanet.py482-493

        if targets is not None:
            for target_idx, target in enumerate(targets):
                boxes = target["boxes"]
                degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
                if degenerate_boxes.any():
                    continue # <----这个continue是我加的,不然训练会停
                    # print the first degenerate box
                    bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
                    degen_bb: List[float] = boxes[bb_idx].tolist()
                    raise ValueError("All bounding boxes should have positive height and width."
                                     " Found invalid box {} for target at index {}."
                                     .format(degen_bb, target_idx))

所以我直接在循环条件里加了continue看看能不能强行训练,后面发现不行(到某个step损失为Inf),因为实际有问题的图片大概有30张,我的batch为8,若batch大于30或许能强行训练,总之需要挑出有问题的图片,并且这一步加的continue不要删。

2.2.1 挑出误标图片

直接上代码吧,最后会有一个err_xml.txt文件显示有问题的图片。按照它,把voc格式的wider face的train.txt中的对应路径给删除。
我判断误标的条件是

  • 长or宽为0
  • 长or宽大于图片尺寸

我怀疑这个标准不够严,因为训练中还是会raise error,所以我把上一步的continue给留着了,后续损失下降,训练成功,可见留下来的错误图片并不会严重影响模型效果。

import os 
import xml.etree.ElementTree as ET
from tqdm import tqdm 

def find_dbox(root):
    err_xmls = open('./err_xmls.txt', 'a')
    xml_files=os.listdir(os.path.abspath(root))
    xml_files.sort()
    print(f'{len(xml_files)} files found')

    err_count =0
    for xml_file in xml_files:
        xml_path = os.path.join(root,xml_file)
        tree = ET.parse(xml_path)
        
        width = int(tree.find('size/width').text)
        height = int(tree.find('size/height').text)
        
        boxes = tree.findall('object/bndbox')
        
        for i,box in enumerate(boxes,start=1):
            x0,y0,x1,y1 = int(box.find('xmin').text),int(box.find('ymin').text),\
            int(box.find('xmax').text),int(box.find('ymax').text)

            if x0 >=x1 or y0>=y1 or (x1-x0)>width or (y1-y0)>height:
                print(xml_path, file=err_xmls)
                # 反注释可以看出是图片里哪个框有问题             
                # print("************************")
                # print(i)
                # print("************************")
                err_count +=1
                break
            
    print(f'there are {err_count} files')
    err_xmls.close()
        
        
if __name__ == '__main__':
    find_dbox(root='./widerface/face2voc/Annotations')
    print('\033[1;32m程序完成\n¯\_(ツ)_/¯\033[0m')

2.2.2 对读取数据集代码的处理

改动my_dataset.py,其实就是各种路径问题,并且需要一个xxx.json文件记录各种类别和index的关系。对于wider face来说,类别就1个:人脸。所以json文件如下

{
	"face": 0
}
class VOCDataSet(Dataset):
    """读取解析PASCAL VOC2007/2012数据集"""

    def __init__(self, voc_root, year="2012", transforms=None, txt_name: str = "train.txt"):
        assert year in ["2007", "2012"], "year must be in ['2007', '2012']"
        # 增加容错能力
        # if "VOCdevkit" in voc_root:
        #     self.root = os.path.join(voc_root, f"VOC{year}")
        # else:
        #     self.root = os.path.join(voc_root, "VOCdevkit", f"VOC{year}")
        self.root = os.path.abspath(voc_root)
        self.img_root = os.path.join(self.root, "JPEGImages")
        self.annotations_root = os.path.join(self.root, "Annotations")

        # read train.txt or val.txt file
        # txt_path = os.path.join(self.root, "ImageSets", "Main", txt_name)
        txt_path = os.path.join(self.root, 'Labels', txt_name)
        assert os.path.exists(txt_path), "not found {} file.".format(txt_name)

        with open(txt_path) as read:
            self.xml_list = [os.path.join(self.annotations_root, line.strip() + ".xml")
                             for line in read.readlines() if len(line.strip()) > 0]

        # check file
        assert len(self.xml_list) > 0, "in '{}' file does not find any information.".format(txt_path)
        for xml_path in self.xml_list:
            assert os.path.exists(xml_path), "not found '{}' file.".format(xml_path)

        # read class_indict
        # json_file = './pascal_voc_classes.json'
        json_file = os.path.join(self.root, 'wider_face_voc.json') 
        assert os.path.exists(json_file), "{} file not exist.".format(json_file)
        with open(json_file, 'r') as f:
            self.class_dict = json.load(f)

        self.transforms = transforms

2.3 训练代码的改动

博主@霹雳吧啦Wz提供了单卡训练的train.py和多卡训练的train_multi_GPU.py,基本改动无非文件路径batchepochs等,较大的改动就是每个epoch完成后的验证集部分,因为他写了一个train_eval_utils.py去计算模型在验证集上的表现,仿照的是pycocotools的结果,总之有些问题,一时半会还需要仔细看看进行修改。我注释掉了验证的步骤,让模型直接训练,每轮下来只记录损失和学习率。最后取损失最低的模型参数。

3 成果展示

经过这一通魔改,本来对结果不太抱有希望,毕竟也不知道模型的mAP等指标到底是多少,尝试送了两张照片进行预测,结果还可以,起码能找出人脸hhhhh
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
后面可能还会继续改一下验证的部分,给自己挖个坑。

未完待续…

Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐