Network Structure Comparison
In the figure below, the upper diagram shows the GRU network structure and the lower one shows the LSTM network structure.
The difference is that the GRU merges the LSTM's forget gate and input gate into a single update gate and folds the cell state into the hidden state; a separate reset gate controls how much of the previous hidden state is used when forming the new candidate state (see the equations below).
Link to the LSTM network introduction
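To make the gate merging concrete, the GRU update can be written as follows (the standard formulation; references differ on whether z_t or 1 - z_t weights the previous state):

$$
\begin{aligned}
z_t &= \sigma(W_z x_t + U_z h_{t-1} + b_z) && \text{(update gate)} \\
r_t &= \sigma(W_r x_t + U_r h_{t-1} + b_r) && \text{(reset gate)} \\
\tilde{h}_t &= \tanh\!\big(W_h x_t + U_h (r_t \odot h_{t-1}) + b_h\big) && \text{(candidate state)} \\
h_t &= (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t && \text{(new hidden state)}
\end{aligned}
$$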
Parameter Count Comparison
From the comparison below, the GRU has fewer parameters, which helps reduce overfitting and makes training more efficient.
Assume the input is:

from keras.layers import Input, Dense, Activation, LSTM, GRU
from keras.models import Model

TIME_STEPS = 28   # number of time steps
INPUT_SIZE = 28   # feature length m at each time step
CELL_SIZE = 100   # number of hidden units n
OUTPUT_SIZE = 10  # output length

inputs = Input(shape=[TIME_STEPS, INPUT_SIZE])
LSTM:
x = LSTM(CELL_SIZE, input_shape=(TIME_STEPS, INPUT_SIZE), return_sequences=False)(inputs)
GRU:
x = GRU(CELL_SIZE, input_shape=(TIME_STEPS, INPUT_SIZE), return_sequences=False)(inputs)
Output:
x = Dense(OUTPUT_SIZE)(x)
x = Activation("softmax")(x)
model = Model(inputs,x)
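The two parameter tables below can be reproduced by printing the model summary:

model.summary()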
LSTM network parameters:
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         (None, 28, 28)            0
_________________________________________________________________
lstm_1 (LSTM)                (None, 100)               51600
_________________________________________________________________
dense_1 (Dense)              (None, 10)                1010
_________________________________________________________________
activation_1 (Activation)    (None, 10)                0
=================================================================
Total params: 52,610
Trainable params: 52,610
Non-trainable params: 0
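As a rough sanity check (a back-of-the-envelope calculation, not part of the summary output), the LSTM layer stacks four gate blocks (input gate, forget gate, cell candidate, output gate), each holding an (m + n) x n weight matrix and an n-dimensional bias, with m = INPUT_SIZE and n = CELL_SIZE:

m, n = INPUT_SIZE, CELL_SIZE             # 28 and 100
lstm_params = 4 * ((m + n) * n + n)      # 4 gate blocks: (m + n) x n weights + n biases each
print(lstm_params)                       # 51600, matching lstm_1 above
dense_params = n * OUTPUT_SIZE + OUTPUT_SIZE
print(lstm_params + dense_params)        # 52610, matching Total params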
GRU network parameters:
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         (None, 28, 28)            0
_________________________________________________________________
gru_1 (GRU)                  (None, 100)               38700
_________________________________________________________________
dense_1 (Dense)              (None, 10)                1010
_________________________________________________________________
activation_1 (Activation)    (None, 10)                0
=================================================================
Total params: 39,710
Trainable params: 39,710
Non-trainable params: 0
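The GRU count follows the same pattern with only three blocks (update gate, reset gate, candidate state), which is exactly where the savings come from. Note that 38,700 corresponds to the GRU variant with reset_after=False (the default in older Keras releases); recent TensorFlow/Keras versions default to reset_after=True, which adds a second bias per block and reports 39,000 instead:

m, n = INPUT_SIZE, CELL_SIZE             # 28 and 100
gru_params = 3 * ((m + n) * n + n)       # 3 blocks: (m + n) x n weights + n biases each
print(gru_params)                        # 38700, matching gru_1 above (reset_after=False)
print(3 * ((m + n) * n + 2 * n))         # 39000 with reset_after=True (TF2 default)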