使用tensorwatch
import os
import torch
import tensorwatch as tw
from torchvision import models
os.environ["PATH"] += r'C:\Program Files (x86)\Graphviz2.38\bin' # 安装graphviz时的路径
model = models.inception_v3(pretrained=True)
dummy_input = torch.rand(1, 3, 347, 347)
tw.draw_model(model, dummy_input).save('inception_v3.pdf')
经过测试,这里的dummy_input最小大小为347,比这小的会报以下错误:
ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 768, 1, 1])
得到的网络结构如下: