Pytorch模型转Caffe

1. 支持的转换算子

github上实现的PytorchToCaffe的代码,支持转换的算子如下(参见:pytorch_to_caffe.py):

F.conv2d=Rp(F.conv2d,_conv2d)
F.linear=Rp(F.linear,_linear)
F.relu=Rp(F.relu,_relu)
F.leaky_relu=Rp(F.leaky_relu,_leaky_relu)
F.max_pool2d=Rp(F.max_pool2d,_max_pool2d)
F.avg_pool2d=Rp(F.avg_pool2d,_avg_pool2d)
F.adaptive_avg_pool2d = Rp(F.adaptive_avg_pool2d,_adaptive_avg_pool2d)
F.dropout=Rp(F.dropout,_dropout)
F.threshold=Rp(F.threshold,_threshold)
F.prelu=Rp(F.prelu,_prelu)
F.batch_norm=Rp(F.batch_norm,_batch_norm)
F.instance_norm=Rp(F.instance_norm,_instance_norm)
F.softmax=Rp(F.softmax,_softmax)
F.conv_transpose2d=Rp(F.conv_transpose2d,_conv_transpose2d)
F.interpolate = Rp(F.interpolate,_interpolate)
F.sigmoid = Rp(F.sigmoid,_sigmoid)
F.tanh = Rp(F.tanh,_tanh)
F.tanh = Rp(F.tanh,_tanh)
F.hardtanh = Rp(F.hardtanh,_hardtanh)
# F.l2norm = Rp(F.l2norm,_l2Norm)

torch.split=Rp(torch.split,_split)
torch.max=Rp(torch.max,_max)
torch.cat=Rp(torch.cat,_cat)
torch.div=Rp(torch.div,_div)
  • 作者重写了caffe的算子,来替换orch.nn算子。其中RP表示替换的意思(Replace)
  • 主要支持转Caffe的算子包括:F.conv2d,F.linear,F.relu,F.leaky_relu,F.max_pool2d,F.avg_pool2d,F.adaptive_avg_pool2d,F.dropout,F.threshold,F.prelu,F.batch_norm,F.instance_norm,F.softmax,F.conv_transpose2d,F.interpolate
  • F.upsampleF.interpolate算子不支持,经过测试上采样操作建议使用F.conv_transpose2d转置卷积替换。其中F.interpolate算子在转换caffe模型时,容易提示upsample_h参数不存在的错误(虽然作者代码中显示支持F.interpolate)。

2. pytoch转Caffe

  • (1) : github上下载PytorchToCaffe的脚本。
    在这里插入图片描述
  • (2): 将Caffe文件夹和pytorch_to_caffe.py文件放到项目根目录
  • (3): 对项目中不支持转caffe的算子,如upsampleF.interpolate,使用F.conv_transpose2d替换。
  • (4): 替换后重新训练pytorch模型,获得训练好的model.pt文件
  • (5): 在项目跟目录上创建convertCaffe.py,利用训练好的.pt文件,转caffe的.prototxt.caffemodel模型文件。convertCaffe.py的代码实现如下:
import sys
sys.path.insert(0,'.')
import torch
from torch.autograd import Variable
from torchvision.models import resnet
import pytorch_to_caffe
from nets.deeplabv3_plus import DeepLab

if __name__=='__main__':
    name = 'deeplab'
    model = DeepLab(8, backbone="mobilenet", downsample_factor=16, pretrained=False)
    #model.load_state_dict(torch.load('logs/best_epoch_weights.pth', map_location='cpu'))
    checkpoint = torch.load("logs/best_epoch_weights.pth")
    model.load_state_dict(checkpoint,False)
    model.eval()
    input=torch.ones([1,3,224,224])
     #input=torch.ones([1,3,224,224])
    pytorch_to_caffe.trans_net(model,input,name)
    pytorch_to_caffe.save_prototxt('{}.prototxt'.format(name))
    pytorch_to_caffe.save_caffemodel('{}.caffemodel'.format(name))

在这里插入图片描述
转换成功会提示Transform Completed

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

@BangBang

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值