- 打印输出图像的维度变化
如果没有安装summary,可以按照这个指令安装:
pip install torchsummary
打印每一层特征维度的代码:
import torch
import torchvision
from torchsummary import summary #使用 pip install torchsummary
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
vgg = torchvision.models.vgg16().to(device)
summary(vgg, input_size=(3, 224, 224))
结构如下:
- 打印网络的可视化图结构
先看最终的结果:
这个图像化,就很简单明了,但是还有一点点不方便的地方,就是网络模型一般都比较大,不能显示完全,或者字体太小,查看或者打印不太方便。建议每一个图按照A4纸尺寸来设计,里面的字体也可以调整。这样可以方便打印。
使用方法:
需要安装一个库文件:graphviz
pip install graphviz
使用代码:
from graphviz import Digraph
import torch
from torch.autograd import Variable
def make_dot(var, params=None):
""" Produces Graphviz representation of PyTorch autograd graph
Blue nodes are the Variables that require grad, orange are Tensors
saved for backward in torch.autograd.Function
Args:
var: output Variable
params: dict of (name, Variable) to add names to node that
require grad (TODO: make optional)
"""
if params is not None:
assert isinstance(params.values()[0], Variable)
param_map = {id(v): k for k, v in params.items()}
node_attr = dict(style='filled',
shape='box',
align='left',
fontsize='12',
ranksep='0.1',
height='0.2')
dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
seen = set()
def size_to_str(size):
return '('+(', ').join(['%d' % v for v in size])+')'
def add_nodes(var):
if var not in seen:
if torch.is_tensor(var):
dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
elif hasattr(var, 'variable'):
u = var.variable
name = param_map[id(u)] if params is not None else ''
node_name = '%s\n %s' % (name, size_to_str(u.size()))
dot.node(str(id(var)), node_name, fillcolor='lightblue')
else:
dot.node(str(id(var)), str(type(var).__name__))
seen.add(var)
if hasattr(var, 'next_functions'):
for u in var.next_functions:
if u[0] is not None:
dot.edge(str(id(u[0])), str(id(var)))
add_nodes(u[0])
if hasattr(var, 'saved_tensors'):
for t in var.saved_tensors:
dot.edge(str(id(t)), str(id(var)))
add_nodes(t)
add_nodes(var.grad_fn)
return dot
if __name__ == '__main__':
from models import KFSGNet
from torch.autograd import Variable
import torch
net = KFSGNet()
x = Variable(torch.randn((1,1,96,96)))
y = net(x)
g = make_dot(y)
g.view()
pass
- 直接打印定义好的模型变量
net = VGGNet()
print(net)
显示结果:
- 打印模块名字和参数的大小
# net.named_parameters() 也是可迭代对象,既能调出网络的具体参数,也有名字信息
for name, parameters in net.named_parameters():
print(name, ';', parameters.size())