《自学习》keras_multi_head的使用方法
import keras
from keras_multi_head import MultiHead
model = keras.models.Sequential()
model.add(keras.layers.Embedding(input_dim=100, output_dim=20, name='Embedding'))
model.add(MultiHead(keras.layers.LSTM(units=32), layer_num=5, name='Multi-LSTMs'))
model.add(keras.layers.Flatten(name='Flatten'))
model.add(keras.layers.Dense(units=4, activation='softmax', name='Dense'))
model.build()
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
Embedding (Embedding) (None, None, 20) 2000
_________________________________________________________________
Multi-LSTMs (MultiHead) (None, 3