Pytorch转Caffe最简单方法

 

由于需要移植模型到比特大陆,华为昇腾这些平台。他们基本都支持caffe的模型,对其他模型支持不太好。用其他方法pytorch转caffe不然就是绕道太多,不然就是很多坑。这里记录一个最简单的方法:

[作者环境: torch 1.2.0 torchvision 0.4.0 ]

pip install pytorch2caffe

import torch
import torchvision
from pytorch2caffe import pytorch2caffe
def SaveDemo():
    from torchvision.models import resnet

    name = 'resnet18'
    resnet18 = resnet.resnet18()
    resnet18.eval()
    dummy_input = torch.ones([1, 3, 224, 224])
    pytorch2caffe.trans_net(resnet18, dummy_input, name)
    pytorch2caffe.save_prototxt('{}.prototxt'.format(name))
    pytorch2caffe.save_caffemodel('{}.caffemodel'.format(name))


if __name__ == '__main__':
    SaveDemo()

如果你的模型中使用了avg_pool 使用这种写法:

x = F.avg_pool2d(x,7)

 

 

 

  • 0
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值