TorchScript学习使用

加载一个训练好的mnist模型:

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary

class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.mp = nn.MaxPool2d(2)
        self.fc = nn.Linear(320, 10)

    def forward(self, x):
        in_size = x.size(0)
        x = F.relu(self.mp(self.conv1(x)))
        x = F.relu(self.mp(self.conv2(x)))
        x = x.view(in_size, -1)  # flatten the tensor
        x = self.fc(x)
        x = F.log_softmax(x,dim=0)
        # print(x[0])
        return x


model = Net()

filedir = os.path.join('model','mnistcnn.pt' )
torch.load( filedir)

# 打印 model参数
# print(model )
# summary(model.cuda(), (1,28,28) )
# print(model.conv1.state_dict())

dummy_input = torch.rand(1, 1, 28, 28)
# IR生成
with torch.no_grad():
    jit_model = torch.jit.trace(model, dummy_input)
    jit_model.save('model/jit_model.pth')

jit_layer1 = jit_model
# print(jit_layer1.graph)
torch._C._jit_pass_inline(jit_layer1.graph)
# print(jit_layer1.code)

load_jit_model = torch.jit.load('model/jit_model.pth')
print(load_jit_model.code)

# C++
# // 加载生成的torchscript模型
# auto module = torch::jit::load('jit_model.pth');
# // 根据任务需求读取数据
# std::vector<torch::jit::IValue> inputs = ...;
# // 计算推理结果
# auto output = module.forward(inputs).toTensor();

reference:

TorchScript 解读(一):初识 TorchScript - 知乎

https://www.jianshu.com/p/a94d49351e05

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值