Pytorch读取模型文件及视频的维度膨胀

在深度学习中,我们载入预训练模型时,经常要查看预训练模型的内容,以便更好使用,我以pytorch为例总结一下简单的操作

模型文件的读取:

pretrained_dict = torch.load(path)#读取下来的pretrained_dict是一个字典

获取模型保存的键值对

for k, v in pretrained_dict.items():

这里的k一般是层名,v是获取的对应层名的内容(有些模型可能保存的形式不同)

获取每一层的权重,并转化成numpy数组

得到的v是parameter格式,包含权重和是否反向传播

我们可以通过v.data得到权重数值,类型为tensor,然后通过v.data.numpy()转化成array,使用v.requires_grad获取是否反向传播的True/False

将numpy转化为parameter:

new_v =nn.Parameter(torch.tensor(array))

保存模型:

torch.save(dict,new_path)

其中dict是字典名,包含键为层名,值为parameter

在视频行为识别中,常常缺少预训练模型,需要图像的预训练模型进行维度膨胀,因此我参考i3d的处理方式,将resnet50的预训练模型进行维度膨胀,仅供参考:

import torch
from torchvision import models
import numpy as np
import torch.nn as nn

pretrained_dict =torch.load(r'/data/zhengrui/dataset/pretrain/resnet50-19c8e357.pth',map_location='cpu')
new_dict = {}
for k, v in pretrained_dict.items():
	new_dict = {}
	print('layer:', k,'\n')
	print('content:',v.data.numpy().shape,'\n')
	print(v.requires_grad)

   	conv_weight = v.data.numpy()
	if 'conv' in k:
       try:
			n = conv_weight.shape[3]
			print(n)
            v1 = np.tile(np.expand_dims(conv_weight / n, 2), [1, 1, n, 1, 1])
			print(v1.shape)
            new_v = nn.Parameter(torch.tensor(v1))
            new_v.requires_grad = v.requires_grad
       except IndexError:
       print('ERROR:',k)   
	elif k == 'layer1.0.downsample.0.weight' or k ==
'layer2.0.downsample.0.weight' or k == 'layer3.0.downsample.0.weight' or k ==
'layer4.0.downsample.0.weight':

       v1 = v.data.unsqueeze(4)
       new_v = nn.Parameter(v1)
       new_v.requires_grad = v.requires_grad

   else:
        new_v = v

  new_dict[k] = new_v

#print(new_dict)
torch.save(new_dict,r'/data/zhengrui/dataset/pretrain/resnet50-2dto3d.pth')
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值