一、
import torch
import torchvision.models
import hiddenlayer as hl
model = VGG()##VGG换成对应的自己的网络结构,但是前面需要全部的网络结构
hl.build_graph(model, torch.zeros([1, 3,480, 480]))#自定义一个初始输入
二、
model = VGG()
x = torch.randn(1,3,480,480)#change 12 to the channel number of network input
y = model(x)
# g = make_dot(y)
# g.view()
make_dot(y, params=dict(list(model.named_parameters())))
这部分跟一需要的东西一样,只是是另一种显示方式
在notebook中执行的