可以使用 tensorflow 的 tf.keras.utils.plot_model
函数来显示模型结构。
示例代码如下:
import tensorflow as tf# 建立模型
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(64, activation='relu'))
model.add(tf.keras.layers.Dense(64, activation='relu'))
model.add(tf.keras.layers.Dense(10, activation='softmax'))
# 显示模型结构
tf.keras.utils.plot_model(model, 'model.png', show_shapes=True)
这段代码会生成一个名为 "model.png" 的图片,显示出模型的结构。
建议在colab中运行,并且在运行之前先安装pydot和graphviz,可以使用 !pip install pydot graphviz 命令安装。