def make_dot(var, params=None):
""" Produces Graphviz representation of PyTorch autograd graph
Blue nodes are the Variables that require grad, orange are Tensors
saved for backward in torch.autograd.Function
Args:
var: output Variable
params: dict of (name, Variable) to add names to node that
require grad (TODO: make optional)
"""
if params is not None:
assert isinstance(params.values()[0], Variable)
param_map = {id(v): k for k, v in params.items()}
node_attr = dict(style='filled',
shape='box',
align='left',
fontsize='12',
ranksep='0.1',
height='0.2')
dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
seen = set()
def size_to_str(size):
return '('+(', ').
pytorch 画network图,并保存pdf。可用于debug和模型架构调整梳理。
最新推荐文章于 2023-11-10 19:58:04 发布
本文介绍如何使用PyTorch绘制神经网络结构图,并将其保存为PDF文件,便于模型调试和架构优化。
摘要由CSDN通过智能技术生成