将caffe预训练模型的权重载入pytorch

6 篇文章 0 订阅
2 篇文章 0 订阅
#!/usr/bin/env python2.7
#coding=utf-8

import caffe
import csv
import numpy as np
# np.set_printoptions(threshold='nan')

# Network definition (prototxt) and pretrained weights (caffemodel) to export.
MODEL_FILE = 'inception_v3_rgb_deploy.prototxt'
PRETRAIN_FILE = 'inception_v3_kinetics_rgb_pretrained.caffemodel'

# Build the net in the caffe.TEST phase with the pretrained weights loaded.
net = caffe.Net(MODEL_FILE, PRETRAIN_FILE, caffe.TEST)

# Flatten the parameters into [w0, b0, w1, b1, ...], following the iteration
# order of net.params. The loader relies on this order.
p = []
for param_name in net.params.keys():
    # print(param_name)
    # NOTE(review): indexes blobs [0] and [1] for every entry, so a layer with
    # a single blob (e.g. a bias-free convolution) raises IndexError here --
    # confirm against the prototxt.
    weight = net.params[param_name][0].data
    bias = net.params[param_name][1].data
    p.append(weight)
    p.append(bias)

# The arrays have different shapes, so store them in an explicit 1-D object
# array (one element per blob) instead of relying on NumPy's implicit
# conversion of a ragged list. Filling element by element prevents NumPy from
# trying to broadcast the blobs into a single N-D array.
params = np.empty(len(p), dtype=object)
for idx, blob in enumerate(p):
    params[idx] = blob
np.save('params.npy', params)

存为numpy之前,要分析网络架构(层名称等)是否对应。

#!/usr/bin/env python3.5
# load net :

"""
https://github.com/pytorch/vision/blob/master/torchvision/models/inception.py

caffemodel has pretrained conv.bias and torchvision doesn't, so change L35, L321

35  def __init__(self, num_classes=400, aux_logits=False, transform_input=False):

320 super(BasicConv2d, self).__init__()
321 self.conv = nn.Conv2d(in_channels, out_channels, bias=True, **kwargs)
"""

# Load the parameters exported to params.npy into a PyTorch Inception3 model,
# save the resulting state dict, then reload it into a fresh model as a check.
import numpy as np
import torch

# NOTE: Inception3 is not imported in this snippet. Per the docstring above it
# is expected to come from a locally modified copy of torchvision's
# inception.py (conv bias enabled, num_classes=400).
net1 = Inception3()
from collections import OrderedDict
from torch import Tensor

# params.npy holds a pickled object array written by Python 2:
#   * encoding='bytes' is needed to unpickle Python 2 data under Python 3;
#   * allow_pickle=True is required because NumPy >= 1.16.3 refuses to load
#     object arrays by default.
# Only load a params.npy file that you produced yourself: unpickling untrusted
# data can execute arbitrary code.
new_params = np.load('params.npy', encoding='bytes', allow_pickle=True)

# state_dict() builds a new dict on every call, so fetch it once.
state = net1.state_dict()
dict_new = state.copy()  # isinstance(dict_new, OrderedDict) is True
new_list = list(state.keys())

ind_new = 0  # index of the next unused array in new_params
for k in new_list:
    # print(k, ind_new)
    # Use this condition instead when num_classes != 400; it skips the fc
    # layer's weight and bias:
    # if k.split('.')[-1] not in ['weight', 'bias'] or k.split('.')[-2] == 'fc':

    # Only entries whose name ends in 'weight' or 'bias' consume an array;
    # every other state-dict entry keeps its current value.
    if k.split('.')[-1] not in ['weight', 'bias']:
        continue
    # reshape() raises ValueError when the element counts differ, which acts
    # as a basic check that the two parameter orders line up.
    tmp = new_params[ind_new].reshape(state[k].shape)
    dict_new[k] = Tensor(tmp)
    # print(Tensor(new_params[ind_new]).shape)
    ind_new += 1

net1.load_state_dict(dict_new)
torch.save(net1.state_dict(), 'pretrain_v3_params.pkl')
print('saved done.')

# Use the pretrained params: reload the saved state dict into a new model.
net2 = Inception3()
net2.load_state_dict(torch.load('pretrain_v3_params.pkl'))
print('loaded done.')

第一段代码用Python2.7的caffe进行提取保存权重,第二段代码用python3.5的Pytorch进行加载。

以下是坑:

  1. Caffe 在 Python 3.5 下我没有编译成功。
  2. 中间文件读取编码方式,Py2与Py3不同
  3. 模型参数名称和Shape要对应

参考:

https://blog.csdn.net/u011762313/article/details/49851795
https://zhuanlan.zhihu.com/p/34147880

# BN_inc: load parameters exported from Caffe into a BN-Inception model.

from collections import OrderedDict
from torch import Tensor
from torch import nn
import numpy as np
import torch

# NOTE: net1 is not created in this snippet; it must already exist.
# Presumably it is a BN-Inception instance whose state-dict order matches the
# order of the arrays in the .npy file -- confirm before running.

# allow_pickle=True is required by NumPy >= 1.16.3 if the file holds a pickled
# object array; encoding='bytes' is needed for pickles written by Python 2.
# Only load files you produced yourself (unpickling can execute code).
new_params = np.load('/home/yaotiechui/pytorch-caffe/bn_params22.npy',
                     encoding='bytes', allow_pickle=True)

# state_dict() builds a new dict on every call, so fetch it once.
state = net1.state_dict()
dict_new = state.copy()  # isinstance(dict_new, OrderedDict) is True
new_list = list(state.keys())

ind_new = 0  # index of the next unused array in new_params
for k in new_list:
    entry = state[k]
    # e.g. inception_5b_pool_proj_bn.num_batches_tracked is 0-dim: Tensor(0)
    print(k, entry.dim(), )
    if entry.dim() == 0:
        # NOTE(review): this stops at the FIRST 0-dim entry, so every entry
        # after it keeps its current value. If the intent is only to skip
        # scalar entries such as num_batches_tracked, this should be
        # `continue` -- confirm against the model's state_dict layout.
        break
    print(entry.shape, new_params[ind_new].shape)
    # reshape() raises ValueError when the element counts differ.
    tmp = new_params[ind_new].reshape(entry.shape)

    dict_new[k] = Tensor(tmp)
    print('new:', dict_new[k].shape)
    ind_new += 1

net1.load_state_dict(dict_new)
torch.save(net1.state_dict(), 'pretrain_bn_params.pkl')
print('saved done.')


# Use the pretrained params: reload the saved state dict into a new model.
# NOTE: bninception is not imported in this snippet -- confirm which package
# provides it.
net2 = bninception()
# Replace the classifier with a 1024 -> 400 linear layer before loading.
net2.last_linear = nn.Linear(1024, 400)
net2.load_state_dict(torch.load('pretrain_bn_params.pkl'))
print('loaded done.')
# can be changed like

def load_pre_model_dict(self, state_dict):
    """Copy matching entries of ``state_dict`` into this model in place.

    For every ``(name, value)`` pair in ``state_dict``:

    * if the name starts with ``"module"``, its first two dot-separated
      components are dropped (presumably a wrapper prefix such as
      ``module.<submodule>.`` -- confirm against the checkpoint);
    * names that are absent from ``self.state_dict()`` are ignored;
    * otherwise the value is copied into the existing tensor with ``copy_``.

    Args:
        self: object exposing ``state_dict()`` that returns a name -> tensor
            mapping.
        state_dict: mapping of parameter names to tensors or
            ``nn.Parameter`` objects.

    Returns:
        None. The tensors returned by ``self.state_dict()`` are modified.
    """
    target = self.state_dict()
    for key, value in state_dict.items():
        if key.startswith("module"):
            # Drop the first two name components, e.g. "module.x.a.w" -> "a.w".
            key = '.'.join(key.split('.')[2:])
            # print('name,',name)
        if key in target:
            print('load....', key)
            if isinstance(value, nn.Parameter):
                print('true')
                # Backwards compatibility for serialized parameters: copy the
                # underlying tensor, not the Parameter wrapper.
                value = value.data
            target[key].copy_(value)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值