转发:https://blog.csdn.net/gyguo95/article/details/78821617
首先要安装 graphviz。
- 这种方法需要安装 python-graphviz:conda install -n pytorch python-graphviz
- 这种方法还需要把下面的代码保存为 visualize.py:
from graphviz import Digraph
import torch
from torch.autograd import Variable
def make_dot(var, params=None):
    """Produce a Graphviz representation of a PyTorch autograd graph.

    Blue nodes are the Variables that require grad, orange nodes are the
    Tensors saved for backward in a ``torch.autograd.Function``.  Every
    other node is labelled with the type name of the graph function.

    Args:
        var: output Variable; the walk starts from ``var.grad_fn``.
        params: optional dict of (name, Variable) used to label the nodes
            of Variables that require grad.  Variables that are not in
            the dict are drawn with an empty name.

    Returns:
        The populated ``Digraph``.

    Raises:
        TypeError: if a value in ``params`` is not a Variable.
    """
    param_map = {}
    if params is not None:
        # Validate by iterating instead of ``params.values()[0]``: dict
        # views cannot be indexed on Python 3, and an empty dict has no
        # first element.  A real exception is used rather than ``assert``
        # so the check survives ``python -O``.
        for key, value in params.items():
            if not isinstance(value, Variable):
                raise TypeError('params[%r] must be a Variable, got %s'
                                % (key, type(value).__name__))
        # Keyed by object identity so graph leaves can be matched back to
        # the names supplied by the caller.
        param_map = {id(v): k for k, v in params.items()}

    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
    seen = set()

    def size_to_str(size):
        """Format a size as ``(d0, d1, ...)``."""
        return '(' + ', '.join(['%d' % v for v in size]) + ')'

    def add_nodes(var):
        """Add ``var`` and, recursively, everything it depends on."""
        if var in seen:
            return

        if torch.is_tensor(var):
            # Tensor saved for the backward pass.
            dot.node(str(id(var)), size_to_str(var.size()),
                     fillcolor='orange')
        elif hasattr(var, 'variable'):
            # Graph leaf that wraps a Variable requiring grad.
            u = var.variable
            # ``.get`` keeps unnamed leaves (e.g. an input that requires
            # grad but is not a model parameter) from raising KeyError.
            name = param_map.get(id(u), '')
            node_name = '%s\n %s' % (name, size_to_str(u.size()))
            dot.node(str(id(var)), node_name, fillcolor='lightblue')
        else:
            dot.node(str(id(var)), str(type(var).__name__))
        seen.add(var)

        if hasattr(var, 'next_functions'):
            for u in var.next_functions:
                if u[0] is not None:
                    dot.edge(str(id(u[0])), str(id(var)))
                    add_nodes(u[0])
        if hasattr(var, 'saved_tensors'):
            for t in var.saved_tensors:
                dot.edge(str(id(t)), str(id(var)))
                add_nodes(t)

    add_nodes(var.grad_fn)
    return dot
import ResNet34
import numpy as np
import torch
from torch.autograd import Variable
from visualize import make_dot
from ResNet34 import NetG
import torch as t
class Config(object):
    """Option container for the generator visualisation script.

    All options are plain class attributes, so they can be read from the
    class itself or from an instance.
    """

    # --- network sizes ---
    # Dimension of the input noise vector.
    nz = 500
    # Number of generator feature maps.
    ngf = 64
    # Number of discriminator feature maps.
    ndf = 64

    # --- training schedule ---
    batch_size = 256
    # Train the generator once every 5 batches.
    g_every = 5

    # --- noise sampling ---
    # Mean of the noise distribution.
    gen_mean = 0
    # Passed as the standard deviation when the script samples noise with
    # normal_(); the original comment described it as the variance.
    gen_std = 2
    # Number of noise samples the script draws.  NOTE(review): the
    # original comment said "keep the best 64 of 512 generated images",
    # which does not match the value 3 -- confirm the intended meaning.
    gen_search_num = 3
    gen_num = 1

    # --- runtime / output ---
    # Whether to use the GPU.
    gpu = False
    gen_img = '2018.png'
if __name__ == '__main__':
    # Build the generator and render the autograd graph of one forward
    # pass through it.
    opt = Config()
    a = NetG(opt)

    # Noise batch of shape (gen_search_num, nz, 1, 1), filled in place
    # with samples from N(gen_mean, gen_std).
    noises = t.randn(opt.gen_search_num, opt.nz, 1, 1).normal_(opt.gen_mean, opt.gen_std)
    # The input must not be volatile.  On PyTorch < 0.4 ``volatile=True``
    # disables graph recording, so ``y.grad_fn`` would be None and there
    # would be nothing to draw.  On PyTorch >= 0.4 the flag has been
    # removed and only triggers a warning.
    noises = Variable(noises)

    y = a(noises)
    print(y.size())

    g = make_dot(y)
    g.view()
    # To write the graph to disk without opening a viewer, use
    # g.render('here', view=False) instead of g.view().