import torch
import torch.nn as nn
class Module(nn.Module):
    """Single 3x3 convolution followed by a flatten, used as an ONNX export example.

    The convolution maps 1 input channel to 1 output channel with stride 1 and
    padding 1, so the spatial size is preserved. All weights are set to 0.3 and
    the bias to 0.2, which makes the output deterministic for a given input.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1, bias=True)
        # Set parameter values through the nn.init helpers instead of mutating
        # ``.data`` directly: they run under ``torch.no_grad()`` and are the
        # documented way to initialise parameters without bypassing autograd.
        nn.init.constant_(self.conv.weight, 0.3)
        nn.init.constant_(self.conv.bias, 0.2)

    def forward(self, x):
        """Apply the convolution, then flatten every non-batch dimension.

        Args:
            x: input tensor; assumed to be (N, 1, H, W), as required by the
               single-input-channel Conv2d — TODO confirm with callers.

        Returns:
            A tensor of shape (N, C*H*W), i.e. (N, H*W) since C == 1 here.
        """
        x = self.conv(x)
        # NOTE(review): the feature size is deliberately wrapped in int(...);
        # presumably this makes the ONNX exporter record it as a constant rather
        # than a chain of Shape/Gather ops, leaving only the batch dim (-1)
        # dynamic. Confirm against the exported graph before replacing this
        # with flatten(), which would change the ONNX graph.
        return x.view(-1, int(x.numel() // x.size(0)))
# Build the example network and switch it to inference mode before export.
model = Module()
model.eval()

# A single 1x1x3x3 sample filled with ones serves as the example input.
x = torch.ones(1, 1, 3, 3)

# Run one eager forward pass; the result is kept in ``y``.
y = model(x)

# Export the model to lesson1.onnx using ``x`` as the example input;
# verbose=True requests detailed output from the exporter.
torch.onnx.export(model, (x,), "lesson1.onnx", verbose=True)
可以在线查看 ONNX 模型结构的网站：Netron（https://netron.app）