pytorch模型加载的各种情况代码整理

pytorch模型加载

基本的情形:

一、单卡训练好的模型文件,推理阶段的加载(即加载模型文件和model定义的是一致的)
state_dict = torch.load('checkpoint.pth.tar')
Mymodel.load_state_dict(state_dict)
二、多gpu训练保存的模型,在单卡情况下加载

当出现以第一种的写法会有问题时,自然的就需要检查Mymodel和要加载的模型文件的差异

# 打印你自己定义的模型的键值对
params = Mymodel.state_dict()  # 获得模型的原始状态以及参数, orderdict数据类型
for k, v in params.items():
    print(k)  # 只打印key值,不打印具体参数值
    
# 打印加载的checkpoint文件的键值对
state_dict = torch.load('checkpoint.pth.tar')
for k, v in checkpoint.items():
    print(k)

然后你就会发现有些key值多了module,这时候我们只需要去掉多余的名字使得两者的key值相同就可以正常加载了。主要有3种:

  • 从key中的第7个字符开始取

    from collections import OrderedDict
    state_dict = torch.load('checkpoint.pth.tar')
    new_state_dict = OrderedDict()
    for k,v in state_dict.items():
        name = k[7:] # remove `module`
        new_state_dict[name] = v
    # load params
    Mymodel.load_state_dict(new_state_dict)
    
  • module 替换为空字符

    name = k.replace('module.', '')
    
  • 最简单的方法,加载模型之后,接着将模型DataParallel,此时就可以load_state_dict。
    如果有多个GPU,将模型并行化,用DataParallel来操作。这个过程会将key值加一个module.

    Mymodel = resnet18()# 实例化自己的模型;
    state_dict = torch.load('checkpoint.pt', map_location='cpu') 
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model) 
    model.load_state_dict(checkpoint) # 可以直接将模型参数load进模型。
    
三、加载模型文件与当前构建模型相同部分层的权重

主要有两种写法:

  1. 在写模型类时,在训练好的模型基础上搭建自己的模型。即先带预训练模型建model,然后再修改、删除或者添加模块;

    featureExtract = resnet18(pretrained=True) # load weight
    self.featureEncoder = nn.Sequential(*list(featureExtract.children())[:-2])
    
  2. 整个模型类已经实现,调用后加载部分层参数

    使用前最好打印检查下,因为设置了strict=false,如果只差了module是不会报错的。

    #第一种方法:
    mymodelB = TheModelBClass(*args, **kwargs)
    # strict=False,设置为false,只保留键值相同的参数
    mymodelB.load_state_dict('checkpoint.pt', strict=False)
    
    #第二种方法:
    # 加载模型
    model_pretrained = torch.load('checkpoint.pt')
    
    # mymodel's state_dict,
    # 如:  conv1.weight 
    #     conv1.bias  
    mymodelB_dict = mymodelB.state_dict()
    
    # 将model_pretrained的建与自定义模型的建进行比较,剔除不同的
    pretrained_dict = {k: v for k, v in model_pretrained.items() if k in mymodelB_dict}
    # 更新现有的model_dict
    mymodelB_dict.update(pretrained_dict)
    
    # 加载我们真正需要的state_dict
    mymodelB.load_state_dict(mymodelB_dict)
    
参考

https://zhuanlan.zhihu.com/p/48524007

讨论:checkpoint 里的module导致键值不匹配

参考blog1

参考bolg2

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值