使用netron工具可视化Pytorch模型和输出参数
1 安装netron
pip install netron
2 导入包
import netron
import torch.onnx
程序调用
if __name__ == '__main__':
net = vgg()
x = Variable(torch.FloatTensor(16, 3, 40, 40))
y = net(x)
print(y.data.shape)
onnx_path = "onnx_model_name.onnx"
torch.onnx.export(net, x, onnx_path)
netron.start(onnx_path)
展示图
3 使用 summary 进行参数的输出
导入包:from torchkeras import summary
调用:summary(resnet,(3,224,224)),注意这里的输入不要带上betch_size