1. 支持的转换算子
github上实现的PytorchToCaffe的代码,支持转换的算子如下(参见:pytorch_to_caffe.py
):
F.conv2d=Rp(F.conv2d,_conv2d)
F.linear=Rp(F.linear,_linear)
F.relu=Rp(F.relu,_relu)
F.leaky_relu=Rp(F.leaky_relu,_leaky_relu)
F.max_pool2d=Rp(F.max_pool2d,_max_pool2d)
F.avg_pool2d=Rp(F.avg_pool2d,_avg_pool2d)
F.adaptive_avg_pool2d = Rp(F.adaptive_avg_pool2d,_adaptive_avg_pool2d)
F.dropout=Rp(F.dropout,_dropout)
F.threshold=Rp(F.threshold,_threshold)
F.prelu=Rp(F.prelu,_prelu)
F.batch_norm=Rp(F.batch_norm,_batch_norm)
F.instance_norm=Rp(F.instance_norm,_instance_norm)
F.softmax=Rp(F.softmax,_softmax)
F.conv_transpose2d=Rp(F.conv_transpose2d,_conv_transpose2d)
F.interpolate = Rp(F.interpolate,_interpolate)
F.sigmoid = Rp(F.sigmoid,_sigmoid)
F.tanh = Rp(F.tanh,_tanh)
F.tanh = Rp(F.tanh,_tanh)
F.hardtanh = Rp(F.hardtanh,_hardtanh)
# F.l2norm = Rp(F.l2norm,_l2Norm)
torch.split=Rp(torch.split,_split)
torch.max=Rp(torch.max,_max)
torch.cat=Rp(torch.cat,_cat)
torch.div=Rp(torch.div,_div)
- 作者重写了caffe的算子,来
替换
orch.nn算子。其中RP
表示替换的意思(Replace) - 主要支持转Caffe的算子包括:
F.conv2d,F.linear,F.relu,F.leaky_relu,F.max_pool2d,F.avg_pool2d,F.adaptive_avg_pool2d,F.dropout,F.threshold,F.prelu,F.batch_norm,F.instance_norm,F.softmax,F.conv_transpose2d,F.interpolate
等 F.upsample
和F.interpolate
算子不支持,经过测试
上采样操作建议使用F.conv_transpose2d
转置卷积替换。其中F.interpolate
算子在转换caffe模型时,容易提示upsample_h
参数不存在的错误
(虽然作者代码中显示支持F.interpolate)。
2. pytoch转Caffe
- (1) : github上下载PytorchToCaffe的脚本。
- (2): 将
Caffe
文件夹和pytorch_to_caffe.py
文件放到项目根目录 - (3): 对项目中不支持转caffe的算子,如
upsample
和F.interpolate
,使用F.conv_transpose2d
替换。 - (4): 替换后重新训练pytorch模型,获得训练好的model.pt文件
- (5): 在项目跟目录上创建
convertCaffe.py
,利用训练好的.pt
文件,转caffe的.prototxt
和.caffemodel
模型文件。convertCaffe.py
的代码实现如下:
import sys
sys.path.insert(0,'.')
import torch
from torch.autograd import Variable
from torchvision.models import resnet
import pytorch_to_caffe
from nets.deeplabv3_plus import DeepLab
if __name__=='__main__':
name = 'deeplab'
model = DeepLab(8, backbone="mobilenet", downsample_factor=16, pretrained=False)
#model.load_state_dict(torch.load('logs/best_epoch_weights.pth', map_location='cpu'))
checkpoint = torch.load("logs/best_epoch_weights.pth")
model.load_state_dict(checkpoint,False)
model.eval()
input=torch.ones([1,3,224,224])
#input=torch.ones([1,3,224,224])
pytorch_to_caffe.trans_net(model,input,name)
pytorch_to_caffe.save_prototxt('{}.prototxt'.format(name))
pytorch_to_caffe.save_caffemodel('{}.caffemodel'.format(name))
转换成功会提示Transform Completed