A Simple Tutorial on Converting a PyTorch Model to Caffe
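1. Save the PyTorch parameter names and their weights as a dictionary in an .npy file
The parameters of a trained model are usually stored in a '.pth' file. Read the parameter names and weights from that file: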
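import numpy as np
import torch

premodel = 'xxxx.pth'
param_dict = {}
pretrained_dict = torch.load(premodel, map_location='cpu')
# Some checkpoints wrap the weights under a 'state_dict' key; unwrap it if present.
if 'state_dict' in pretrained_dict:
    pretrained_dict = pretrained_dict['state_dict']
for layer, value in pretrained_dict.items():
    # Store each tensor as a NumPy array keyed by its parameter name.
    param_dict[str(layer)] = value.detach().cpu().numpy()
# np.save pickles the dict into a 0-d object array.
np.save('xxx.npy', param_dict)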
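Before moving on, it can help to reload the .npy file and check the parameter names and shapes. The sketch below is only an illustration: the dictionary contents and the temporary file path are hypothetical stand-ins for the real weights, so that the example runs without PyTorch.

```python
import os
import tempfile

import numpy as np

# Hypothetical stand-in for the dictionary built from the .pth file.
param_dict = {
    "conv1.weight": np.zeros((8, 3, 3, 3), dtype=np.float32),
    "conv1.bias": np.zeros(8, dtype=np.float32),
}

path = os.path.join(tempfile.mkdtemp(), "params.npy")
np.save(path, param_dict)  # the dict is pickled into a 0-d object array

# allow_pickle=True is required for object arrays; .item() unwraps the dict.
loaded = np.load(path, allow_pickle=True).item()

# Collect the shape of every parameter so it can be compared with the prototxt.
shapes = {name: weights.shape for name, weights in loaded.items()}
for name, shape in shapes.items():
    print(name, shape)
```

The printed names are the ones the Caffe layer names will need to correspond to in the next step.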
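2. Build the Caffe prototxt file. It must mirror the PyTorch network structure, and the layer names should map to the PyTorch parameter names in a consistent way. The file can be written by hand or generated automatically through the Python interface.
1. Write the prototxt file by hand, following the PyTorch model.
Netron, a model visualization tool, is recommended here: viewing the network graph makes it easier to check and refine the prototxt. It also helps to be familiar with Caffe's basic layers and the general prototxt format.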
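When writing the prototxt by hand, it is useful to settle the name correspondence up front. The following is a minimal sketch of one possible convention, not a fixed rule: the helper names `split_param_name` and `caffe_target` are hypothetical, and replacing dots with underscores is just a naming choice. It relies on the fact that Caffe convolution and InnerProduct layers keep the weight in blob 0 and the bias in blob 1.

```python
# Blob index inside a Caffe Convolution / InnerProduct layer.
BLOB_INDEX = {"weight": 0, "bias": 1}


def split_param_name(name):
    """Split a PyTorch key such as 'features.0.weight' into (layer, kind)."""
    layer, _, kind = name.rpartition(".")
    return layer, kind


def caffe_target(name):
    """Return (caffe_layer_name, blob_index), or None for unhandled keys."""
    layer, kind = split_param_name(name)
    if kind not in BLOB_INDEX:
        # e.g. BatchNorm statistics, which need separate handling
        return None
    # Assumed convention: dots in the PyTorch name become underscores.
    return layer.replace(".", "_"), BLOB_INDEX[kind]


for key in ("conv1.weight", "layer1.0.conv1.bias", "bn1.running_mean"):
    print(key, "->", caffe_target(key))
```

Keys that return None, such as BatchNorm running statistics, do not fit the simple weight/bias pattern and would need their own mapping.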
2. Generate the prototxt file automatically with the Python interface.
Example:
import caffe
from caffe import layers as L, params as P

def net(dbfile, batch_size):
    n = caffe.NetSpec()
    # Data layer: reads images and labels from LMDB and scales pixels by 1/256
    n.data, n.label = L.Data(source=dbfile, backend=P.Data.LMDB, batch_size=batch_size, ntop=2, transform_param=dict(scale=0.00390625))
    # Fully connected layer; "ip1" becomes the layer name
    n.ip1 = L.InnerProduct(n.data, num_output=500, weight_filler=dict(type='xavier'))
    n.relu1 = L.ReLU(n.ip1, in_place=True)
    n.ip2 = L.InnerProduct(n.relu1, num_output=10, weight_filler=dict(type='xavier'))
    n.loss = L.SoftmaxWithLoss(n.ip2, n.label)
    # Accuracy is only evaluated in the TEST phase
    n.accu = L.Accuracy(n.ip2, n.label, include={'phase': caffe.TEST})
    return n.to_proto()

with open('auto_train00.prototxt', 'w') as f:
    f.write(str(net('/home/hbk/caffe/examples/mnist/mnist_train_lmdb', 64)))
with open('auto_test00.prototxt', 'w') as f:
    f.write(str(net('/home/hbk/caffe/examples/mnist/mnist_test_lmdb', 100)))

# Solver used for training; its hyperparameters are set in the solver prototxt
solver = caffe.SGDSolver('hbk_mnist_solver_py.prototxt')
solver.test_nets[0].forward()  # one forward pass on the test net
solver.step(1)                 # a single training iteration
solver.solve()                 # train until max_iter is reached
Example solver prototxt (the file passed to the solver in the code above). Look up the individual parameters for details.
# Paths to the train/test net files
train_net: "auto_train00.prototxt"
test_net: "auto_test00.prototxt"
# test_iter specifies how many forward passes the test should carry out.
# In the case of MNIST, we have test batch size 100 and 100 test iterations,
# covering the full 10,000 testing images.
test_iter: 100
# Run a test pass every test_interval training iterations
test_interval: 500
# The base learning rate, momentum and the weight decay of the network.
base_lr: 0.01
momentum: 0.9
weight_decay: 0.0005
# The learning rate policy
lr_policy: "inv"
gamma: 0.0001
power: 0.75
# Print progress information every `display` iterations
display: 100
# The maximum number of iterations
max_iter: 10001
# Save intermediate snapshots
snapshot: 5000
snapshot_prefix: "snapshot"
# solver mode: CPU or GPU
solver_mode: GPU
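To make the learning rate settings above more concrete: with lr_policy "inv", Caffe computes the rate as base_lr * (1 + gamma * iter) ^ (-power). The sketch below simply evaluates that formula with the values from this solver file; the function name `inv_lr` and the sampled iterations are illustrative choices.

```python
def inv_lr(iteration, base_lr=0.01, gamma=0.0001, power=0.75):
    """Learning rate under Caffe's "inv" policy at a given iteration."""
    return base_lr * (1.0 + gamma * iteration) ** (-power)


# Sample the schedule at a few iterations up to roughly max_iter.
for it in (0, 500, 5000, 10000):
    print(it, inv_lr(it))
```

The rate starts at base_lr and decays smoothly, so by the end of training it is noticeably smaller than the initial 0.01.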
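2. Write the prototxt by hand, making sure the layer names follow a pattern that matches the PyTorch names.
Take the data layer as an example, corresponding to the Python code above.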