神经网络的结构图一般在论文里是不可少的一部分内容,把代码组成的神经网络展现在论文里还是一件挺麻烦的事情,自己用PPT画半天费时费力不说,遇到多维张量更是不知道怎么画了。今天介绍一个python包——visual keras可以辅助我们完成这部分工作。
从名字也可以看出来了,visual keras适合使用kears API构建的神经网络,并可视化它的结构,不过好像使用tensorflow构建的网络也可以。同时支持网络层风格的可视化和节点风格的可视化,这两种风格在后面会进行演示。
下面介绍使用方法:
首先安装依赖
pip install visualkeras
出现successfully installed visualkeras-0.0.2就算成功了,注意一下这里的0.0.2版本,这是目前最新的版本,旧版本有不少问题。
下面使用keras定义一个模型试试
这里可以说明一下,这个包同时支持sequential model和function model。随便用就可以了。
model = models.Sequential()
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
visual keras比较方便的地方就是模型不用编译,不用训练,只需要给出模型结构代码就行了。
构建好模型,直接用以下代码就可以绘制结构图了,下图分别是层风格图和节点风格图。
visualkeras.layered_view(model)
visualkeras.graph_view(model)
visualkeras.layered_view(model,legend=True)
legend可以显示标签,visualkeras.layered_view函数还有一个font参数可以自定义标签的字体,不过font要自己先定义,有点麻烦就不弄了,稍微注意一下的就是使用font要先导入from PIL import ImageFont。
这个是加上全连接层和flatten层的样子
对于循环神经网络和一维卷积网络或者双向循环网络都没有什么不同,直接按上述方法使用就行。 不过使用循环神经网络经常会出现这个 int object is not iterable的问题,这里就要注意你的visual keras的版本了,一般都是版本问题
使用下面代码更新到最新版本就可以了。如果有其他问题也可以去作者的githubhttps://github.com/paulgavrikov/visualkeras/issues看一下issues里面有十几个比较常见的问题,有助于解决使用过程的很多问题。这个int对象没法迭代的问题有提问者和作者battle了好几个来回最后发现确实是自己没有更新版本。更新完了还不行的话,可以尝试先uninstall visualkeras再安装,还不行的话记得重启一次整个工作环境,要不然更过了可能没反应。
pip install git+https://github.com/paulgavrikov/visualkeras --upgrade
Collecting git+https://github.com/paulgavrikov/visualkeras
下面还有一些其他技巧是把3D图转为2D图,更换层级颜色,自定义每个层之间的距离
//2D显示模型结构
visualkeras.layered_view(model, legend=True, font=font, draw_volume=False)
//将层级之间的距离设置为50
visualkeras.layered_view(model, legend=True, font=font, draw_volume=False,spacing=50)
//更换每一个层的颜色
from collections import defaultdict
color_map = defaultdict(dict)
color_map[layers.Conv2D]['fill'] = 'orange'
color_map[layers.MaxPooling2D]['fill'] = 'red'
color_map[layers.Dense]['fill'] = 'black'
color_map[layers.Flatten]['fill'] = 'teal'
visualkeras.layered_view(model, legend=True, font=font,color_map=color_map)
这里可以展示一个VGG16网络二维的表示,其他的功能就不一一展示了。