# 加载一个训练好的 MNIST 模型 (load a trained MNIST model):
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
class Net(nn.Module):
    """Small two-convolution CNN for MNIST digit classification.

    Expects input of shape (N, 1, 28, 28): the 28x28 spatial size is implied
    by ``fc`` taking 320 features (20 channels x 4 x 4 after two
    conv(k=5) + max-pool(2) stages). ``forward`` returns log-probabilities of
    shape (N, 10), one row per sample.
    """

    def __init__(self):
        super().__init__()
        # Attribute names are kept as-is: they are the state_dict keys a
        # saved checkpoint is matched against.
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)   # 1x28x28  -> 10x24x24
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)  # 10x12x12 -> 20x8x8
        self.mp = nn.MaxPool2d(2)                      # halves height and width
        self.fc = nn.Linear(320, 10)                   # 20*4*4 -> 10 classes

    def forward(self, x):
        """Return per-class log-probabilities for a batch ``x`` of shape (N, 1, 28, 28)."""
        in_size = x.size(0)
        x = F.relu(self.mp(self.conv1(x)))
        x = F.relu(self.mp(self.conv2(x)))
        x = x.view(in_size, -1)  # flatten to (N, 320)
        x = self.fc(x)
        # Normalize over the class dimension (dim=1). dim=0 would normalize
        # across the batch, which for a batch of one yields all zeros.
        return F.log_softmax(x, dim=1)
# Build the network and restore the trained weights.
model = Net()
filedir = os.path.join("model", "mnistcnn.pt")
# torch.load only returns the checkpoint; it must be applied to the model
# explicitly, otherwise the traced model below keeps its random init weights.
# map_location="cpu" lets a GPU-saved checkpoint load on a CPU-only machine,
# matching the CPU dummy input used for tracing.
# NOTE(review): the checkpoint may be a state_dict or a whole pickled module;
# both are accepted here. If it is a pickled module and torch >= 2.6 is used,
# torch.load needs weights_only=False (only do that for trusted files).
checkpoint = torch.load(filedir, map_location="cpu")
if isinstance(checkpoint, nn.Module):
    checkpoint = checkpoint.state_dict()
model.load_state_dict(checkpoint)
# Inference mode before tracing (matters if dropout/batch-norm are ever added).
model.eval()

# Print model parameters (debug aids):
# print(model)
# summary(model.cuda(), (1, 28, 28))
# print(model.conv1.state_dict())

# Example input with the shape Net expects: (batch=1, channels=1, 28, 28).
dummy_input = torch.rand(1, 1, 28, 28)

# IR generation: trace the model into TorchScript and save it for LibTorch.
jit_path = os.path.join("model", "jit_model.pth")
with torch.no_grad():
    jit_model = torch.jit.trace(model, dummy_input)
    jit_model.save(jit_path)

jit_layer1 = jit_model
# print(jit_layer1.graph)
# Private API: inlines calls in the in-memory graph for inspection only.
# It runs after save(), so the saved file is unaffected.
torch._C._jit_pass_inline(jit_layer1.graph)
# print(jit_layer1.code)

# Reload the saved TorchScript module and print its generated code.
load_jit_model = torch.jit.load(jit_path)
print(load_jit_model.code)
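# C++ (LibTorch) usage:
# // 加载生成的torchscript模型 (load the generated TorchScript model)
# auto module = torch::jit::load("jit_model.pth");
# // 根据任务需求读取数据 (prepare the input data for your task)
# std::vector<torch::jit::IValue> inputs = ...;
# // 计算推理结果 (run inference)
# auto output = module.forward(inputs).toTensor();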
# Reference: