3、将 .pth 模型转换为 ONNX
我们根据上面 mnist.pth 的网络结构，自己构造一个对应的模型：
import torch
import torch.nn as nn
import torch.nn.functional as F
class LeNet(nn.Module):
def __init__(self):
# Build the layers of a small LeNet-style CNN: two conv layers followed by
# three fully connected layers.  Layer names and sizes presumably mirror the
# parameter keys/shapes stored in mnist.pth so its state_dict can be loaded
# into this module -- TODO confirm against the checkpoint's keys.
# NOTE(review): indentation appears to have been lost in this paste; the
# statements below must be indented under `def __init__` (and the method
# under `class LeNet`) for the code to parse.
super(LeNet, self).__init__()  # initialize nn.Module internals before any submodule is assigned
# 1 input channel -> 6 feature maps, 3x3 kernel, stride 1, no padding.
self.conv1 = nn.Conv2d(in_channels=1,out_channels=6,kernel_size=3,stride=1,padding=0)
# 6 -> 16 feature maps, 3x3 kernel, stride 1, no padding.
self.conv2 = nn.Conv2d(in_channels=6,out_channels=16,kernel_size=3,stride=1,padding=0)
# 400 = 16 * 5 * 5: the flattened feature map after conv2 + 2x2 max-pool,
# per the shape comments in forward() (26x26 after conv1 implies a 28x28 input).
self.fc1 = nn.Linear(400, 120)
self.fc2 = nn.Linear(120, 84)
# 10 outputs -- presumably one score per MNIST digit class (confirm).
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
out = self.conv1(x) # torch.Size([1, 6, 26, 26])
out = F.max_pool2d(F.relu(out), 2) # [1, 6, 13, 13]
out = self.conv2(out) # [1, 16, 11, 11]
out = F.max_pool2d(F.relu(out), 2) # [1, 16, 5, 5]
out = out.view(out.size(0), -1) # [1, 400]
out = self.fc1(ou