PyTorch模型转caffe

本文介绍了如何将训练好的PyTorch模型转换为Caffe格式,包括保存PyTorch模型参数,创建对应的Caffe prototxt文件,建立caffemodel文件并进行模型测试,确保转换前后性能一致。
摘要由CSDN通过智能技术生成

1.将Pytorch 模型参数名和对应权重保留,存成字典,存入npy文件

训练好的模型文件参数权重可以保存在‘.pth’文件中,从该文件读取参数权重的数据:

premodel='xxxx.pth'
param_dict={
   }
pretrained_dict=torch.load(premodel,map_location='cpu')
if 'state_dict'==pretrained_dict['state_dict']:	
	for layer,value in pretrained_dict.items():
		layer=str(layer)
		param_dict[layer]=value.detach()
else:
	pass
	
np.save('xxx.npy',param_dict)

2. 建立caffe的prototxt文件。对应pytorch的网络结构,参数名字要有对应规律,可以使用python接口写,然后自动生成。

1.手写prototxt文件,根据pytorch的模型。

推荐caffe模型可视化软件Netron,可以可视化进行改进。了解caffe的基本层和格式的基本形式。

2.采用python接口自动生成prototxt文件

示例:

import caffe 
from pylab import *
import caffe.layers as L
import caffe.params as P
def net():
   n=caffe.NetSpec()
   n.data,n.label=L.Data(source=dbfile,backen=xxx.LMDB, batch_size=batch_size, ntop=2, transform_param=dict(scale=0.00390625))#数据层 
   n.ip1=L.InnerProduct(n.data,num_output=500,weight_file=dict(type='xavier'))#全连接层 ip1是层的name
   n.relu1=L.ReLU(n.ip1,in_place=True)
   n.ip2=L.InnerProduct(n.relu1, num_output=10, weight_filler=dict(type='xavier'))
   n.loss= L.SoftmaxWithLoss(n.ip2, n.label)
   n.accu= L.Accuracy(n.ip2, n.label, include={
   'phase':caffe.TEST})
   return n.to_proto()
with open( 'auto_train00.prototxt', 'w') as f:
    f.write(str(net( '/home/hbk/caffe/examples/mnist/mnist_train_lmdb', 64)))
with open('auto_test00.prototxt', 'w') as f:
    f.write(str(net('/home/hbk/caffe/examples/mnist/mnist_test_lmdb', 100)))
#进行训练的solver 训练参数的填写
solver=caffe.SGDSovlver('hbk_mnist_solver_py.prototxt')
solver.test_nets[0].forward()

solver.step(1)
solver.solve()
   

solver.prototxt的代码示例。具体的参数说明可以自行搜索

# The train/test net 文件路径
train_net: "auto_train00.prototxt"
test_net: "auto_test00.prototxt"
# test_iter specifies how many forward passes the test should carry out.
# In the case of MNIST, we have test batch size 100 and 100 test iterations,
# covering the full 10,000 testing images.
test_iter: 100

# 训练迭代多少次执行一次Test验证
test_interval: 500

# The base learning rate, momentum and the weight decay of the network.
base_lr: 0.01
momentum: 0.9
weight_decay: 0.0005

# The learning rate policy
lr_policy: "inv"
gamma: 0.0001
power: 0.75

# 多少次迭代输出一次信息
display: 100
# The maximum number of iterations
max_iter: 10001
# 存储中间结果
snapshot: 5000
snapshot_prefix: "snapshot"

# solver mode: CPU or GPU
solver_mode: GPU

2,自己手写,注意网络名字,采用和pytorch形式匹配的形式

以数据层为例,和上面的Python代码相对应。


                
  • 1
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值