pytorch 之 加载不同形式的预训练模型

我们在学习pytorch时,不可避免的要加载不同的预训练模型。而且pytorch下的预训练模型有很多种形式,我们又该如何加载呢。今天,我就为大家介绍三种常用的模型形式以及其加载方式。

1.pth形式和.pt形式,这种形式的模型使我们较为常见的形式,保存的可以是整个模型含模型结构及参数)只保存参数
当保存的是整个模型时,读取方式有两种:
   1).直接加载:

if __name__=='__main__':
    with torch.no_grad():
        model=InceptionI3d(num_classes=400,in_channels = 3)
        if not os.path.exists('./rgb_imagenet.pt'):
            print ('No weights Found! please download first, or comment 382~384th line')
        
        model.load_state_dict('./rgb_imagenet.pt')#通过load_state_dict()函数来加载

   2).通过state_dict()函数先读参数再保存

if __name__=='__main__':
    with torch.no_grad():
        model=InceptionI3d(num_classes=400,in_channels = 3)
        if not os.path.exists('./rgb_imagenet.pt'):
            print ('No weights Found! please download first, or comment 382~384th line')
        
        weight = torch.load("./rgb_imagenet.pt")['model']# 'model'为字典键值
        model.load_state_dict(weight.state_dict())

当仅保存模型参数时: 这种形式一般保存的是一个字典型state_dict:权重,所以要加工一次。

if __name__ == "__main__":
    with torch.no_grad():
        net = TSN(num_class, this_test_segments if is_shift else 1, modality,
                      base_model='resnet50',
                      consensus_type = crop_fusion_type,
                      img_feature_dim = img_feature_dim,
                      pretrain = pretrain,
                      is_shift=is_shift, shift_div=shift_div, shift_place=shift_place,
                      non_local='_nl' in this_weights,
                      )
        weights = 'TSM_something_RGB_resnet50_shift8_blockres_avg_segment16_e45.pth'#指定路径
        
        checkpoint = torch.load(this_weights)#通过load函数读出来
        checkpoint = checkpoint['state_dict']#取出state_sict所对应的权重名称和权重,这里有可能字典键值不是state_dict,可以使用.keys()函数查看具体键值
        base_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint.items())}#一般是为了除去权重名称前的module.前缀,这个可根据自己的需要添加。也可以打印出所有的权重名称.items()。
        
        net.load_state_dict(base_dict)#同样用load_state_dict加载。

2.pth.tar形式,这种形式的模型保存的通常也是字典型,但是不仅仅state_sict一项,可以用print(set(checkpoint))打印查看,我们可以不管其他内容,因此和2的处理方式基本相同。

if __name__ == "__main__":
    with torch.no_grad():
        model = MultiColumn(174, Model, 512)
        checkpoint_path = './model_best.pth.tar'#指定路径
        checkpoint = torch.load(checkpoint_path)#使用load函数读值
        #print(set(checkpoint))
        checkpoint = checkpoint['state_dict']
        base_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint.items())}#一般是为了除去权重名称前的module.前缀,这个可根据自己的需要添加。可以打印出所有的权重称.items()。
        model.load_state_dict(checkpoint['state_dict'])#使用load_state_dict()加载权重。

这里的介绍就先到这里啦,大家要学会灵活运用。

  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值