Keras 模型可视化
model.summary()
可以查看基本情况Sequential
使用summary()
基本没问题,但是模型如果复杂多变,summary
方法无法表示模型的空间结构- 介绍Keras的
keras.utils.plot_model
方法,优点在于:
安装需要的环境
- pydot-ng
- graphviz
- 本机是Centos,用 yum install graphviz
- Ubuntu,应该是 apt-get install graphviz
示例
build_model
建立一个Seq2Seq(相当复杂的模型)- 使用
plot_model
生成模型结构的图片,结构清楚,很棒 summary
方法完全看不出模型的空间结构
import random
import numpy as np
from keras import layers
from keras.layers import Input, Embedding, Bidirectional, Dense, Concatenate, LSTM
from keras.models import Model, load_model
from keras.utils import plot_model
def build_model(num_encoder_tokens=20,
                num_decoder_tokens=100,
                encoder_embedding_dim=20,
                decoder_embedding_dim=100,
                latent_dim=256):
    """Build and compile a seq2seq model with a bidirectional LSTM encoder.

    All sizes were previously hard-coded; they are now keyword parameters
    whose defaults reproduce the original model exactly.

    Args:
        num_encoder_tokens: encoder input vocabulary size.
        num_decoder_tokens: decoder input/output vocabulary size.
        encoder_embedding_dim: encoder embedding vector dimension.
        decoder_embedding_dim: decoder embedding vector dimension.
        latent_dim: LSTM units per encoder direction. The decoder uses
            2 * latent_dim units so its state sizes match the concatenated
            forward/backward encoder states.

    Returns:
        A compiled Model mapping [encoder_inputs, decoder_inputs] to
        per-timestep softmax distributions over the decoder vocabulary.
    """
    # --- Encoder: embedding -> bidirectional LSTM, keeping final states. ---
    encoder_inputs = Input(shape=(None,), name='encoder_inputs')
    encoder_embedding = Embedding(num_encoder_tokens, encoder_embedding_dim,
                                  name='encoder_embedding')(encoder_inputs)
    bidi_encoder_lstm = Bidirectional(
        LSTM(latent_dim, return_state=True, dropout=0.2,
             recurrent_dropout=0.5),
        name='encoder_lstm')
    # Bidirectional(return_state=True) yields: outputs, then h/c states for
    # the forward and backward directions. Outputs are unused here.
    _, forward_h, forward_c, backward_h, backward_c = bidi_encoder_lstm(encoder_embedding)
    # Concatenate the two directions so the decoder can be initialized
    # from a single (2 * latent_dim)-wide state pair.
    state_h = Concatenate()([forward_h, backward_h])
    state_c = Concatenate()([forward_c, backward_c])
    encoder_states = [state_h, state_c]

    # --- Decoder: embedding -> LSTM seeded with the encoder states. ---
    decoder_inputs = Input(shape=(None,), name='decoder_inputs')
    decoder_embedding = Embedding(num_decoder_tokens, decoder_embedding_dim,
                                  name='decoder_embedding')(decoder_inputs)
    decoder_lstm = LSTM(latent_dim * 2, return_state=True,
                        return_sequences=True, dropout=0.2,
                        recurrent_dropout=0.5, name='decoder_lstm')
    rnn_outputs, *decoder_states = decoder_lstm(decoder_embedding,
                                                initial_state=encoder_states)
    # Per-timestep projection onto the decoder vocabulary.
    decoder_dense = Dense(num_decoder_tokens, activation='softmax',
                          name='decoder_dense')
    decoder_outputs = decoder_dense(rnn_outputs)

    bidi_encoder_model = Model([encoder_inputs, decoder_inputs], [decoder_outputs])
    bidi_encoder_model.compile(optimizer='adam', loss='categorical_crossentropy')
    return bidi_encoder_model
# Build the seq2seq model, render its graph to a PNG, then print the
# textual summary for comparison with the plotted structure.
seq2seq = build_model()
plot_model(seq2seq, to_file='seq2seq_model.png', show_shapes=True)
seq2seq.summary()
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
encoder_inputs (InputLayer) (None, None) 0
__________________________________________________________________________________________________
encoder_embedding (Embedding) (None, None, 20) 400 encoder_inputs[0][0]
__________________________________________________________________________________________________
decoder_inputs (InputLayer) (None, None) 0
__________________________________________________________________________________________________
encoder_lstm (Bidirectional) [(None, 512), (None, 567296 encoder_embedding[0][0]
__________________________________________________________________________________________________
decoder_embedding (Embedding) (None, None, 100) 10000 decoder_inputs[0][0]
__________________________________________________________________________________________________
concatenate_1 (Concatenate) (None, 512) 0 encoder_lstm[0][1]
encoder_lstm[0][3]
__________________________________________________________________________________________________
concatenate_2 (Concatenate) (None, 512) 0 encoder_lstm[0][2]
encoder_lstm[0][4]
__________________________________________________________________________________________________
decoder_lstm (LSTM) [(None, None, 512), 1255424 decoder_embedding[0][0]
concatenate_1[0][0]
concatenate_2[0][0]
__________________________________________________________________________________________________
decoder_dense (Dense) (None, None, 100) 51300 decoder_lstm[0][0]
==================================================================================================
Total params: 1,884,420
Trainable params: 1,884,420
Non-trainable params: 0
__________________________________________________________________________________________________