size mismatch for head.weight: copying a param with shape torch.Size([1000, 768]) from checkpoint

Problem description


I was training the ViT classification model from GitHub (https://github.com/YINYIPENG-EN/vit_classification_pytorch) on my own dataset, and training raised the following error.

RuntimeError: Error(s) in loading state_dict for VisionTransformer:
        size mismatch for head.weight: copying a param with shape torch.Size([1000, 768]) from checkpoint, the shape in current model is torch.Size([3, 768]).
        size mismatch for head.bias: copying a param with shape torch.Size([1000]) from checkpoint, the shape in current model is torch.Size([3]).

Cause of the error: the FC layer does not match. My dataset has 3 classes, while the downloaded pretrained weights were trained on ImageNet's 1000 classes, so the checkpoint's classification head (head.weight, head.bias) has a different shape from the head the current model builds.
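You can confirm the mismatch before training by printing the head shapes stored in the checkpoint. A quick sketch (vit-patch_16.pth is the weight file used below; swap in your own path):

import torch

ckpt = torch.load('weights/vit-patch_16.pth', map_location='cpu')
print(ckpt['head.weight'].shape)  # torch.Size([1000, 768]): the ImageNet head
print(ckpt['head.bias'].shape)    # torch.Size([1000])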

Solution

# Code before modification (train.py)
import torch
# get_classes, vit, and weights_init are helpers from the repo

def train(opt):
    classes_path = 'weights/cls_classes.txt'
    pretrained = False
    val_split = 0.1
    class_names, num_classes = get_classes(classes_path)
    if opt.model == 'vit':
        model = vit(num_classes=num_classes, pretrained=pretrained)
    if not pretrained:
        weights_init(model)
    if opt.weight != '':
        print('Loading {} into state dict...'.format(opt.weight))
        device = torch.device('cuda' if opt.cuda else 'cpu')
        model_dict = model.state_dict()
        pretrained_dict = torch.load(opt.weight, map_location=device)
        # Note: this condition compares the whole key sets, so it keeps either
        # every entry or none; the names match here, but the head shapes do not,
        # which is what triggers the size mismatch below
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if pretrained_dict.keys() == model_dict.keys()}
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
    model_train = model.train()

Method 1: ignore the FC layer

Find the place where the pth file is loaded (the ckpt below) and pop the fully connected layer's parameters out of it before loading.

# Code after modification
def train(opt):
    classes_path = 'weights/cls_classes.txt'
    pretrained = False
    val_split = 0.1
    class_names, num_classes = get_classes(classes_path)
    if opt.model == 'vit':
        model = vit(num_classes=num_classes, pretrained=pretrained)
    if not pretrained:
        weights_init(model)
    if opt.weight != '':
        print('Loading {} into state dict...'.format(opt.weight))
        device = torch.device('cuda' if opt.cuda else 'cpu')
        # vit-patch_16.pth is my weight file; change it to your own
        ckpt = torch.load('weights/vit-patch_16.pth', map_location=device)
        # Pop the classifier head so the 1000-class weights are never copied
        ckpt.pop('head.weight')
        ckpt.pop('head.bias')
        # strict=False loads every remaining layer and leaves the new
        # num_classes head with its random initialization
        model.load_state_dict(ckpt, strict=False)
    model_train = model.train()
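With strict=False, load_state_dict returns the keys it could not match, so it is easy to verify that only the head was skipped:

result = model.load_state_dict(ckpt, strict=False)
print('missing keys:', result.missing_keys)        # should list only head.weight, head.bias
print('unexpected keys:', result.unexpected_keys)  # should be empty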

Method 2: drop the layers whose parameters don't match

def train(opt):
    classes_path = 'weights/cls_classes.txt'
    pretrained = False
    val_split = 0.1
    class_names, num_classes = get_classes(classes_path)
    if opt.model == 'vit':
        model = vit(num_classes=num_classes, pretrained=pretrained)
    if not pretrained:
        weights_init(model)
    if opt.weight != '':
        print('Loading {} into state dict...'.format(opt.weight))
        device = torch.device('cuda' if opt.cuda else 'cpu')
        model_dict = model.state_dict()
        pretrained_dict = torch.load(opt.weight, map_location=device)
        # Drop the layers whose parameters don't match; mine is named 'head'
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if (k in model_dict and 'head' not in k)}
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
    model_train = model.train()
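Method 2 hard-codes the layer name 'head'. If you would rather not name the layer, a shape-based filter (my generalization, not code from the repo) keeps every tensor that actually fits the current model:

pretrained_dict = torch.load(opt.weight, map_location=device)
model_dict = model.state_dict()
# Keep a checkpoint tensor only when the model has an entry with the same
# name and the same shape; anything else keeps its random initialization
filtered = {k: v for k, v in pretrained_dict.items()
            if k in model_dict and v.shape == model_dict[k].shape}
model_dict.update(filtered)
model.load_state_dict(model_dict)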
