While reproducing the object detection model DETR, I ran into an error.
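Cause:
Search results attribute this error to a mismatch between the torch version used for training and the one used for testing. In my case the real cause was different: the class-count parameter I set when generating the pretrained weights file with mydataset did not match the one used by the model. The original code targets the COCO dataset, so it uses 91 classes; adjust this value to the number of classes in your own dataset.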
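To see which parameters actually disagree, one option is to compare tensor shapes between the checkpoint and the model before loading. The sketch below is only an illustration: `find_shape_mismatches` is a hypothetical helper, not part of DETR, and it works on plain name-to-shape dictionaries. The example shapes assume the COCO checkpoint's class head has 92 rows (91 + 1) and that the model was built for 3 classes with a hidden size of 256.

```python
def find_shape_mismatches(checkpoint_shapes, model_shapes):
    """Return {name: (checkpoint_shape, model_shape)} for shared keys whose shapes differ."""
    mismatches = {}
    for name, ckpt_shape in checkpoint_shapes.items():
        model_shape = model_shapes.get(name)
        # Only report keys that exist on both sides but disagree in shape.
        if model_shape is not None and tuple(model_shape) != tuple(ckpt_shape):
            mismatches[name] = (tuple(ckpt_shape), tuple(model_shape))
    return mismatches


# Illustrative shapes only (assumed, not read from a real file).
checkpoint_shapes = {
    "class_embed.weight": (92, 256),
    "class_embed.bias": (92,),
    "query_embed.weight": (100, 256),
}
model_shapes = {
    "class_embed.weight": (4, 256),
    "class_embed.bias": (4,),
    "query_embed.weight": (100, 256),
}

mismatches = find_shape_mismatches(checkpoint_shapes, model_shapes)
for name, (ckpt, model) in mismatches.items():
    print(f"{name}: checkpoint {ckpt} vs model {model}")
```

With real tensors, the two dictionaries could be built with `{k: tuple(v.shape) for k, v in state_dict.items()}`.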
I want to detect 3 classes, so I set num_class in mydataset to 3. The mydataset code is as follows:
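import torch

# Load the pretrained DETR-R50 checkpoint
pretrained_weights = torch.load('detr-r50-e632da11.pth')
num_class = 3  # number of classes; one extra row is added below for the background class
# Shrink the classification head in place to num_class + 1 outputs
pretrained_weights["model"]["class_embed.weight"].resize_(num_class+1, 256)
pretrained_weights["model"]["class_embed.bias"].resize_(num_class+1)
# Save the adjusted weights; with num_class = 3 the file is detr-r50_3.pth
torch.save(pretrained_weights, "detr-r50_%d.pth"%num_class)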
Running it produces a new .pth file.
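An alternative to the in-place `resize_` calls is to slice the two tensors and copy them, which leaves the loaded checkpoint untouched. This is a minimal sketch under the same assumptions as the script above (keys `class_embed.weight` and `class_embed.bias`, one extra row for the background class). `trim_class_head` is a hypothetical helper name, and the dummy tensors stand in for the real checkpoint.

```python
import torch


def trim_class_head(state_dict, num_class):
    """Return a copy of state_dict whose class head keeps only the first num_class + 1 rows."""
    trimmed = dict(state_dict)  # shallow copy; untouched tensors are shared, not duplicated
    rows = num_class + 1  # +1 for the background (no-object) class
    trimmed["class_embed.weight"] = state_dict["class_embed.weight"][:rows].clone()
    trimmed["class_embed.bias"] = state_dict["class_embed.bias"][:rows].clone()
    return trimmed


# Dummy tensors with an assumed 92-row head, standing in for the real checkpoint.
dummy = {
    "class_embed.weight": torch.randn(92, 256),
    "class_embed.bias": torch.randn(92),
}
trimmed = trim_class_head(dummy, num_class=3)
print(trimmed["class_embed.weight"].shape, trimmed["class_embed.bias"].shape)
```

Applied to the real file, the function would be called on `pretrained_weights["model"]` and the result stored back under the same key before `torch.save`.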
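Then change the num_classes parameter in detr.py so that it matches the pretrained weights file; for the detailed steps, see the DETR reproduction post. Adapt the value to your own dataset so that the parameters are consistent between training and testing.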
def build(args):
    # the `num_classes` naming here is somewhat misleading.
    # it indeed corresponds to `max_obj_id + 1`, where max_obj_id
    # is the maximum id for a class in your dataset. For example,
    # COCO has a max_obj_id of 90, so we pass `num_classes` to be 91.
    # As another example, for a dataset that has a single class with id 1,
    # you should pass `num_classes` to be 2 (max_obj_id + 1).
    # For more details on this, check the following discussion
    # https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223
    '''num_classes = 9 if args.dataset_file != 'coco' else 91
    if args.dataset_file == "coco_panoptic":
        # for panoptic, we just add a num_classes that is large enough to hold
        # max_obj_id + 1, but the exact value doesn't really matter
        num_classes = 250'''
    num_classes = 3
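Following the `max_obj_id + 1` rule described in the comment above, the value can also be derived from the dataset instead of being hard-coded. The sketch below assumes COCO-style annotations with a `categories` list of `{"id": ...}` entries; `num_classes_from_categories` and the sample categories are hypothetical.

```python
def num_classes_from_categories(categories):
    """Return max category id + 1, following the rule in the build() comment."""
    if not categories:
        raise ValueError("categories is empty")
    return max(cat["id"] for cat in categories) + 1


# Hypothetical 3-class dataset whose ids start at 0.
zero_based = [{"id": 0, "name": "a"}, {"id": 1, "name": "b"}, {"id": 2, "name": "c"}]
# The same three classes with ids starting at 1.
one_based = [{"id": 1, "name": "a"}, {"id": 2, "name": "b"}, {"id": 3, "name": "c"}]

print(num_classes_from_categories(zero_based))  # 3
print(num_classes_from_categories(one_based))   # 4
```

Under this rule, `num_classes = 3` fits a 3-class dataset whose ids run from 0 to 2; if the ids run from 1 to 3, the rule gives 4, and the pretrained weights file would need to be generated with the same value.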
Then run main.py.