使用wider face/自己的数据集训练faster rcnn等模型的记录
所以我直接在循环条件里加了continue看看能不能强行训练,后面发现不行(到某个step损失为Inf),因为实际有问题的图片大概有30张,我的batch为8,若batch大于30或许能强行训练,总之需要挑出有问题的图片,并且这一步加的continue不要删。我怀疑这个标准不够严,因为训练中还是会raise error,所以我把上一步的continue给留着了,后续损失下降,训练成功,可见留下来的
文章目录
1 模型代码来源
我使用的模型是B站博主@霹雳吧啦Wz的代码,他写模型的读取方法时考虑的是pascal voc这个数据集,因此如果我们如果使用其他数据集,难免会遇见一些问题,这里记录了我是如何解决的。
ps: 如果使用ultralytics版本的yolov8,他们的代码会有忽略错误标框的能力。
2 对Wider Face数据集的训练处理
2.1 转换为voc格式
参考这篇的做法,一套流程下来,voc格式,coco格式和yolo格式都有了。
2.2 错误标框导致训练不能进行的问题
我一开始时用他仓库目录下的retinanet这个模型,代码有一定的容错能力,在碰见宽度/长度为0的框时会raise error,具体的代码在retinaNet/network_files/retinanet.py
的482-493
行
if targets is not None:
for target_idx, target in enumerate(targets):
boxes = target["boxes"]
degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
if degenerate_boxes.any():
continue # <----这个continue是我加的,不然训练会停
# print the first degenerate box
bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
degen_bb: List[float] = boxes[bb_idx].tolist()
raise ValueError("All bounding boxes should have positive height and width."
" Found invalid box {} for target at index {}."
.format(degen_bb, target_idx))
所以我直接在循环条件里加了continue看看能不能强行训练,后面发现不行(到某个step损失为Inf),因为实际有问题的图片大概有30张,我的batch为8,若batch大于30或许能强行训练,总之需要挑出有问题的图片,并且这一步加的continue不要删。
2.2.1 挑出误标图片
直接上代码吧,最后会有一个err_xml.txt
文件显示有问题的图片。按照它,把voc格式的wider face的train.txt
中的对应路径给删除。
我判断误标的条件是:
- 长or宽为0
- 长or宽大于图片尺寸
我怀疑这个标准不够严,因为训练中还是会raise error,所以我把上一步的continue给留着了,后续损失下降,训练成功,可见留下来的错误图片并不会严重影响模型效果。
import os
import xml.etree.ElementTree as ET
from tqdm import tqdm
def find_dbox(root):
err_xmls = open('./err_xmls.txt', 'a')
xml_files=os.listdir(os.path.abspath(root))
xml_files.sort()
print(f'{len(xml_files)} files found')
err_count =0
for xml_file in xml_files:
xml_path = os.path.join(root,xml_file)
tree = ET.parse(xml_path)
width = int(tree.find('size/width').text)
height = int(tree.find('size/height').text)
boxes = tree.findall('object/bndbox')
for i,box in enumerate(boxes,start=1):
x0,y0,x1,y1 = int(box.find('xmin').text),int(box.find('ymin').text),\
int(box.find('xmax').text),int(box.find('ymax').text)
if x0 >=x1 or y0>=y1 or (x1-x0)>width or (y1-y0)>height:
print(xml_path, file=err_xmls)
# 反注释可以看出是图片里哪个框有问题
# print("************************")
# print(i)
# print("************************")
err_count +=1
break
print(f'there are {err_count} files')
err_xmls.close()
if __name__ == '__main__':
find_dbox(root='./widerface/face2voc/Annotations')
print('\033[1;32m程序完成\n¯\_(ツ)_/¯\033[0m')
2.2.2 对读取数据集代码的处理
改动my_dataset.py
,其实就是各种路径问题,并且需要一个xxx.json
文件记录各种类别和index的关系。对于wider face来说,类别就1个:人脸。所以json
文件如下
{
"face": 0
}
class VOCDataSet(Dataset):
"""读取解析PASCAL VOC2007/2012数据集"""
def __init__(self, voc_root, year="2012", transforms=None, txt_name: str = "train.txt"):
assert year in ["2007", "2012"], "year must be in ['2007', '2012']"
# 增加容错能力
# if "VOCdevkit" in voc_root:
# self.root = os.path.join(voc_root, f"VOC{year}")
# else:
# self.root = os.path.join(voc_root, "VOCdevkit", f"VOC{year}")
self.root = os.path.abspath(voc_root)
self.img_root = os.path.join(self.root, "JPEGImages")
self.annotations_root = os.path.join(self.root, "Annotations")
# read train.txt or val.txt file
# txt_path = os.path.join(self.root, "ImageSets", "Main", txt_name)
txt_path = os.path.join(self.root, 'Labels', txt_name)
assert os.path.exists(txt_path), "not found {} file.".format(txt_name)
with open(txt_path) as read:
self.xml_list = [os.path.join(self.annotations_root, line.strip() + ".xml")
for line in read.readlines() if len(line.strip()) > 0]
# check file
assert len(self.xml_list) > 0, "in '{}' file does not find any information.".format(txt_path)
for xml_path in self.xml_list:
assert os.path.exists(xml_path), "not found '{}' file.".format(xml_path)
# read class_indict
# json_file = './pascal_voc_classes.json'
json_file = os.path.join(self.root, 'wider_face_voc.json')
assert os.path.exists(json_file), "{} file not exist.".format(json_file)
with open(json_file, 'r') as f:
self.class_dict = json.load(f)
self.transforms = transforms
2.3 训练代码的改动
博主@霹雳吧啦Wz提供了单卡训练的train.py
和多卡训练的train_multi_GPU.py
,基本改动无非文件路径
、batch
、epochs
等,较大的改动就是每个epoch完成后的验证集部分,因为他写了一个train_eval_utils.py
去计算模型在验证集上的表现,仿照的是pycocotools
的结果,总之有些问题,一时半会还需要仔细看看进行修改。我注释掉了验证的步骤,让模型直接训练,每轮下来只记录损失和学习率。最后取损失最低的模型参数。
3 成果展示
经过这一通魔改,本来对结果不太抱有希望,毕竟也不知道模型的mAP等指标到底是多少,尝试送了两张照片进行预测,结果还可以,起码能找出人脸hhhhh
后面可能还会继续改一下验证的部分,给自己挖个坑。
未完待续…

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。
更多推荐
所有评论(0)