1. Function signature and parameters
tf.keras.utils.plot_model(
    model,
    to_file='model.png',
    show_shapes=False,
    show_layer_names=True,
    rankdir='TB',
    expand_nested=False,
    dpi=96
)
Arguments:
  model: A Keras model instance.
  to_file: File name of the plot image.
  show_shapes: Whether to display shape information.
  show_layer_names: Whether to display layer names.
  rankdir: rankdir argument passed to PyDot, a string specifying the format of the plot: 'TB' creates a vertical plot; 'LR' creates a horizontal plot.
  expand_nested: Whether to expand nested models into clusters.
  dpi: Dots per inch.
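The defaults in the signature above can be kept in one place and overridden per call. The sketch below uses a hypothetical helper, plot_kwargs, which is not part of Keras. It only knows the six keyword arguments listed here (newer TensorFlow releases add more), and, as an assumption of this sketch, it accepts only the two documented rankdir values.

```python
# Hypothetical helper (not a Keras API): start from the defaults shown in the
# plot_model signature and let each call override only what it needs.
PLOT_MODEL_DEFAULTS = {
    "to_file": "model.png",
    "show_shapes": False,
    "show_layer_names": True,
    "rankdir": "TB",
    "expand_nested": False,
    "dpi": 96,
}


def plot_kwargs(**overrides):
    """Merge overrides into the defaults, rejecting unknown names and bad values."""
    unknown = set(overrides) - set(PLOT_MODEL_DEFAULTS)
    if unknown:
        raise TypeError(f"unexpected plot_model arguments: {sorted(unknown)}")
    kwargs = {**PLOT_MODEL_DEFAULTS, **overrides}
    # Restrict rankdir to the two layouts documented above.
    if kwargs["rankdir"] not in ("TB", "LR"):
        raise ValueError("rankdir must be 'TB' (vertical) or 'LR' (horizontal)")
    # dpi controls the resolution of the rendered image.
    if not isinstance(kwargs["dpi"], int) or kwargs["dpi"] <= 0:
        raise ValueError("dpi must be a positive integer")
    return kwargs


print(plot_kwargs(rankdir="LR", show_shapes=True, dpi=300))
```

Given a model like the one in section 2, tf.keras.utils.plot_model(model, **plot_kwargs(rankdir='LR', show_shapes=True)) would draw the diagram horizontally with shapes shown and every other option left at its default.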
2. Usage example
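import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

model = keras.Sequential([
    layers.LSTM(32, input_shape=[5, 20], return_sequences=True, kernel_regularizer=keras.regularizers.l2(0.01)),
    layers.LSTM(32, return_sequences=False),
    layers.Dense(3, activation='softmax'),
])
# Vertical layout with shapes and layer names; dpi=900 yields a very large PNG.
# expand_nested has no visible effect here because this model contains no nested models.
tf.keras.utils.plot_model(model, to_file='model3.png', show_shapes=True, show_layer_names=True, rankdir='TB', dpi=900, expand_nested=True)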
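With show_shapes=True, each box in the diagram is labeled with the layer's input and output shapes. The sketch below gives a rough idea of what to expect for this model by propagating the shapes by hand in plain Python. The helper functions are hypothetical, not Keras APIs, and None stands for the unspecified batch dimension.

```python
# Hypothetical sketch: hand-propagate output shapes through LSTM -> LSTM -> Dense,
# mirroring the shapes show_shapes=True annotates on the diagram.

def lstm_output_shape(input_shape, units, return_sequences):
    """Output shape of an LSTM for a (batch, timesteps, features) input."""
    batch, timesteps, _features = input_shape
    if return_sequences:
        # One output vector per timestep.
        return (batch, timesteps, units)
    # Only the output of the last timestep is returned.
    return (batch, units)


def dense_output_shape(input_shape, units):
    """Dense acts on the last axis and keeps all leading axes."""
    return tuple(input_shape[:-1]) + (units,)


def expected_shapes():
    shape = (None, 5, 20)  # input_shape=[5, 20] plus the batch axis
    result = []
    shape = lstm_output_shape(shape, 32, return_sequences=True)
    result.append(("LSTM(32, return_sequences=True)", shape))
    shape = lstm_output_shape(shape, 32, return_sequences=False)
    result.append(("LSTM(32, return_sequences=False)", shape))
    shape = dense_output_shape(shape, 3)
    result.append(("Dense(3)", shape))
    return result


for label, shape in expected_shapes():
    print(label, shape)
# LSTM(32, return_sequences=True) (None, 5, 32)
# LSTM(32, return_sequences=False) (None, 32)
# Dense(3) (None, 3)
```

The kernel_regularizer on the first LSTM does not change any shape, so it is left out of the sketch.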
3. Possible issues
While using it, I ran into the following error:
raise ImportError('Failed to import pydot. You must install pydot'
ImportError: Failed to import pydot. You must install pydot and graphviz for `pydotprint` to work.
Installing pydot and graphviz, as the message says, fixes it. I installed both with conda.
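It can help to check the two dependencies separately, because they are different things. pydot is a Python package, while Graphviz provides the dot executable that pydot runs to render the image. For example, pip install graphviz only installs a Python wrapper, not the executable. The following is a minimal pre-flight sketch. It assumes the plain pydot package and a dot binary on PATH. plot_dependencies_ok is a hypothetical name, and this is not the check Keras itself performs.

```python
import importlib.util
import shutil


def plot_dependencies_ok():
    """Return (has_pydot, has_dot) for the two pieces plot_model relies on."""
    # find_spec looks for the pydot package without importing it.
    has_pydot = importlib.util.find_spec("pydot") is not None
    # Graphviz installs the `dot` executable; shutil.which searches PATH for it.
    has_dot = shutil.which("dot") is not None
    return has_pydot, has_dot


has_pydot, has_dot = plot_dependencies_ok()
if not has_pydot:
    print("pydot is missing: install the pydot package")
if not has_dot:
    print("Graphviz 'dot' was not found on PATH: install graphviz")
if has_pydot and has_dot:
    print("pydot and Graphviz found")
```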