解决DETR模型参数不匹配size mismatch for class_embed.weight:~([4, 256]) from checkpoint,~current([3,256]).

在复现目标检测模型DETR时,遇到

原因:

根据搜索得到的结果是因为训练模型时的torch版本和测试时的不一致,其实是在使用mydataset生成预训练文件时设置的类别参数不一致导致的。源代码由于用的是coco数据集,所以用的应该是91种类别,这里根据自己数据集的类别数进行调整就行。

我的目标是要检测3种类别,所以我将mydataset的num_classes设置为3,mydataset代码如下:

import torch
pretrained_weights  = torch.load('detr-r50-e632da11.pth')


num_class = 3    #类别数+1,1为背景
pretrained_weights["model"]["class_embed.weight"].resize_(num_class+1, 256)
pretrained_weights["model"]["class_embed.bias"].resize_(num_class+1)
torch.save(pretrained_weights, "detr-r50_%d.pth"%num_class)

运行后生成一个.pth文件

 然后修改detr.py中的num_classes参数跟预训练文件的参数保持一致就可以了,具体修改方法看DETR复现。然后根据自己的数据集进行修改,让训练与测试时的参数一致即可。

def build(args):
    # the `num_classes` naming here is somewhat misleading.
    # it indeed corresponds to `max_obj_id + 1`, where max_obj_id
    # is the maximum id for a class in your dataset. For example,
    # COCO has a max_obj_id of 90, so we pass `num_classes` to be 91.
    # As another example, for a dataset that has a single class with id 1,
    # you should pass `num_classes` to be 2 (max_obj_id + 1).
    # For more details on this, check the following discussion
    # https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223
    '''num_classes = 9 if args.dataset_file != 'coco' else 91
    if args.dataset_file == "coco_panoptic":
        # for panoptic, we just add a num_classes that is large enough to hold
        # max_obj_id + 1, but the exact value doesn't really matter
        num_classes = 250'''
    num_classes = 3

然后再运行main.py即可。

  • 3
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值