Drawing a Caffe network structure diagram with Caffe's draw_net.py on Windows 10 + Anaconda3
1. Required installations
Step 1: Install pydotplus. In a cmd command window, enter: pip install pydotplus
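Step 2: Install Graphviz 2.38. Download the Graphviz installer from http://www.graphviz.org/Download_windows.php
and complete the installation (for example, with the install directory C:/Program Files (x86)/Graphviz2.38/bin/). Then right-click "This PC", open "Properties", go to the system variables, and add C:/Program Files (x86)/Graphviz2.38/bin/ to path.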
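Before moving on, it can help to confirm that both prerequisites are visible from Python. The following is a minimal sketch, assuming Python 3 (as shipped with Anaconda3) and assuming that the Graphviz executable is named `dot`; the helper name `check_prerequisites` is hypothetical and not part of Caffe.

```python
import importlib.util
import shutil


def check_prerequisites(module_name="pydotplus", executable="dot"):
    """Report whether a Python module and an executable can be located."""
    # find_spec returns None when a top-level module is not installed
    module_found = importlib.util.find_spec(module_name) is not None
    # shutil.which searches the directories listed in the PATH variable
    executable_path = shutil.which(executable)
    return {"module_found": module_found, "executable_path": executable_path}


if __name__ == "__main__":
    status = check_prerequisites()
    print("pydotplus importable:", status["module_found"])
    print("Graphviz dot found at:", status["executable_path"])
```

If the second line prints `None`, the Graphviz bin directory is probably not on path yet; a newly opened cmd window is needed for a changed path to take effect.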
2. Modify draw_net.py and draw.py
(1) Open python/caffe/draw.py under the caffe root directory and modify the part of the get_layer_label function that follows elif layer.type == 'Pooling':
so that the function reads as follows:
def get_layer_label(layer, rankdir):
    """Define node label based on layer type.

    Parameters
    ----------
    layer : ?
    rankdir : {'LR', 'RL', 'TB', 'BT'}
        Direction of graph layout.

    Returns
    -------
    string :
        A label for the current layer
    """
    if rankdir in ('TB', 'BT'):
        # If graph orientation is vertical, horizontal space is free and
        # vertical space is not; separate words with spaces
        separator = ' '
    else:
        # If graph orientation is horizontal, vertical space is free and
        # horizontal space is not; separate words with newlines
        separator = '\\n'
    if layer.type == 'Convolution' or layer.type == 'Deconvolution':
        # Outer double quotes needed or else colon characters don't parse
        # properly
        node_label = '"%s%s(%s)%skernel size: %d%sstride: %d%spad: %d"' %\
                     (layer.name,
                      separator,
                      layer.type,
                      separator,
                      layer.convolution_param.kernel_size[0] if len(layer.convolution_param.kernel_size._values) else 1,
                      separator,
                      layer.convolution_param.stride[0] if len(layer.convolution_param.stride._values) else 1,
                      separator,
                      layer.convolution_param.pad[0] if len(layer.convolution_param.pad._values) else 0)
    elif layer.type == 'Pooling':
        pooling_types_dict = get_pooling_types_dict()
        node_label = '"%s%s(%s %s)%skernel size: %d%sstride: %d%spad: %d"' %\
                     (layer.name,
                      separator,
                      pooling_types_dict[layer.pooling_param.pool],
                      layer.type,
                      separator,
                      layer.pooling_param.kernel_size,
                      separator,
                      layer.pooling_param.stride,
                      separator,
                      layer.pooling_param.pad)
    else:
        node_label = '"%s%s(%s)"' % (layer.name, separator, layer.type)
    return node_label
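To see what the Pooling branch produces without loading Caffe, the same format string can be applied to stand-in objects. This is only a sketch: the `SimpleNamespace` objects stand in for the protobuf layer message, the `pooling_types` mapping is a hypothetical stand-in for the result of `get_pooling_types_dict()`, the function name `pooling_label` is hypothetical, and the layer values are made up for illustration.

```python
from types import SimpleNamespace


def pooling_label(layer, rankdir, pooling_types):
    """Build a Pooling node label with the same format string as above."""
    # Vertical layouts separate fields with spaces; horizontal layouts use
    # an escaped newline (a backslash followed by the letter n)
    separator = ' ' if rankdir in ('TB', 'BT') else '\\n'
    return '"%s%s(%s %s)%skernel size: %d%sstride: %d%spad: %d"' % (
        layer.name, separator,
        pooling_types[layer.pooling_param.pool], layer.type, separator,
        layer.pooling_param.kernel_size, separator,
        layer.pooling_param.stride, separator,
        layer.pooling_param.pad)


# Illustrative stand-in for a parsed Pooling layer (values are made up)
layer = SimpleNamespace(
    name='pool1', type='Pooling',
    pooling_param=SimpleNamespace(pool=0, kernel_size=2, stride=2, pad=0))
pooling_types = {0: 'MAX', 1: 'AVE', 2: 'STOCHASTIC'}

print(pooling_label(layer, 'TB', pooling_types))
# prints "pool1 (MAX Pooling) kernel size: 2 stride: 2 pad: 0" (including the quotes)
```

The outer double quotes are part of the label itself, which is what lets the colon characters pass through to Graphviz intact.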
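(2) If you are working in PyCharm and do not want to use the command window, edit python/draw_net.py under the caffe root directory and change the main function as follows: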
def main():
    """
    input_net_proto_file: input network prototxt file.
    output_image_file: output image file.
    phase: which network phase to draw: can be TRAIN, TEST, or ALL. Default=ALL.
        If ALL, then all layers are drawn regardless of phase.
    rankdir: one of TB (top-bottom, i.e., vertical), BT (bottom-top, i.e., vertical),
        LR (left-right, i.e., horizontal), RL (right-left, i.e., horizontal).
    """
    # the file of the net prototxt
    input_net_proto_path = '../examples/mnist/'
    input_net_proto_name = 'lenet_train_test.prototxt'
    input_net_proto_file = input_net_proto_path + input_net_proto_name
    net = caffe_pb2.NetParameter()
    text_format.Merge(open(input_net_proto_file).read(), net)
    rankdir = 'LR'
    # draw the TRAIN or TEST phase network
    # keep the phase name separately so the output file name stays readable
    # (e.g. Lenet_ALL.jpg rather than Lenet_None.jpg)
    phase_name = "ALL"
    if phase_name == "TRAIN":
        phase = caffe.TRAIN
    elif phase_name == "TEST":
        phase = caffe.TEST
    else:
        # ALL: draw every layer regardless of phase
        phase = None
    # the path to a file where the network visualization will be stored
    output_net_proto_path = '../examples/mnist/'
    output_net_proto_file = output_net_proto_path + 'Lenet_' + phase_name + '.jpg'
    draw.draw_net_to_file(net, output_net_proto_file, rankdir, phase)
**input_net_proto_file: the prototxt file.
output_image_file: the name and format of the saved image.**
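If you prefer the command window in PyCharm, enter the following command there (the relative paths assume the caffe root directory as the working directory):
python ./python/draw_net.py ./examples/mnist/lenet_train_test.prototxt lenet.jpg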
You can also specify where the image is saved: python ./python/draw_net.py ./examples/mnist/lenet_train_test.prototxt ./examples/mnist/lenet.jpg
saves the image in the examples/mnist/ directory.
In the commands above, --rankdir defaults to LR and --phase defaults to ALL.
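How these defaults behave can be illustrated with a small `argparse` sketch. It is modeled on the command line shown above and is not a copy of Caffe's own parser; the function name `build_parser` and the `choices` lists are assumptions made for illustration.

```python
from argparse import ArgumentParser


def build_parser():
    """Build a parser with two positional arguments and two optional flags."""
    parser = ArgumentParser(description='Draw a network graph (illustrative sketch)')
    parser.add_argument('input_net_proto_file', help='Input network prototxt file')
    parser.add_argument('output_image_file', help='Output image file')
    # Layout direction; falls back to LR when the flag is omitted
    parser.add_argument('--rankdir', default='LR', choices=['LR', 'RL', 'TB', 'BT'])
    # Phase to draw; falls back to ALL when the flag is omitted
    parser.add_argument('--phase', default='ALL', choices=['TRAIN', 'TEST', 'ALL'])
    return parser


# Same arguments as the first command above, with no optional flags given
args = build_parser().parse_args(
    ['./examples/mnist/lenet_train_test.prototxt', 'lenet.jpg'])
print(args.rankdir, args.phase)  # prints: LR ALL
```

Passing `--rankdir TB --phase TEST` after the two file arguments would override both defaults in this sketch.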