tensorflow2 中绘制网络模型结构图
主题:pydot,环境tensorflow2.0
- 安装pydot之前,要事先安装anaconda;
- conda命令行当中执行命令,它会自动去下载pydot和graphviz等相关依赖包的
conda install pydot
或者使用清华镜像源,速度快
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple pydot
- 安装完成后,执行命令查看已安装的包,找一下有没有pydot包,判断是否安装成功
conda list
- 如果这个时候,你已经打开了pycharm或者jupyter notebook等编译器,最好重启一下,这样能够自动刷新安装的包,才可以正常使用
关于pydot的使用
1. 语法参数
tf.keras.utils.plot_model(
model,
to_file='model.png',
show_shapes=False,
show_layer_names=True,
rankdir='TB',
expand_nested=False,
dpi=96
)
2. 使用方式
在创建神经网络模型model后,一般是通过mode.summary()可以查看网络的模型,如果模型太长太复杂的时候,这种方式看起来不是很直观,可以在模型model创建完成后,使用
导包+数据加载代码
模型创建代码 (以简单的线性层为例)
model = tf.keras.Sequential([
tf.keras.layers.Dense(10, input_shape=(3,), activation='relu'),
tf.keras.layers.Dense(1)
])
tf.keras.utils.plot_model(model, to_file='dense.png', show_shapes=True, show_layer_names=True)
在你的程序文件目录下会有一张dense.png,这就是你创建的模型结构
画出的网络结构图还是比较美观的