深度学习之循环神经网络(9)LSTM层使用方法
在TensorFlow中,同样有两种方式实现LSTM网络。既可以使用LSTMCell来手动完成时间戳上面的循环运算,也可以通过LSTM层方式一步完成前向运算。
1. LSTMCell
LSTMCell的用法和SimpleRNNCell基本一致,区别在于LSTM的状态向量List有两个,即 [ h t , c t ] [\boldsymbol h_t,\boldsymbol c_t] [ht,ct],需要分别初始化,其中List第一个元素为 h t \boldsymbol h_t ht,第二个元素为 c t \boldsymbol c_t ct。调用cell完成前向运算时,返回两个元素,第一个元素为cell的输出,也就是 h t \boldsymbol h_t ht,第二个元素为cell的更新后的状态List: [ h t , c t ] [\boldsymbol h_t,\boldsymbol c_t] [ht,ct]。首先新建一个状态向量长度 h = 64 h=64 h=64的LSTM Cell,其中状态向量 c t \boldsymbol c_t ct和输出向量 h t \boldsymbol h_t ht的长度都为 h h h,代码如下:
import tensorflow as tf
from tensorflow.keras import layers
x = tf.random.normal([2, 80, 100])
xt = x[:, 0, :] # 得到一个时间戳的输入
print(xt.shape) # 查看xt,即输入的shape
cell = layers.LSTMCell(64) # 创建LSTM Cell
# 初始化状态和输出List,[h,c]
state = [tf.zeros([2, 64]), tf.zeros([2, 64])]
out, state = cell(xt, state) # 前向计算
# 查看输出out,h以及c的shape
print(out.shape, state[0].shape, state[1].shape)
# 查看返回元素的id
print(id(out), id(state[0]), id(state[1]))
运行结果如下所示:
(2, 100)
(2, 64) (2, 64) (2, 64)
140455569736272 140455569736272 140455569737504
可以看到,返回的输出out和List的第一个元素
h
t
\boldsymbol h_t
ht的id是相同的,这与基础的RNN初衷一致,都是为了格式的统一。
通过在时间戳上展开循环运算,即可完成一次层的前向传播,写法与基础的RNN一样。例如:
# 在序列长度维度上解开,循环送入LSTM Cell单元
for xt in tf.stack(x, axis=1):
# 前向计算
out, state = cell(xt, state) # 前向计算
输出共80个循环。
输出可以仅使用最后一个时间戳上的输出,也可以聚合所有时间戳上的输出向量。
2. LSTM层
通过layers.LSTM层可以方便地一次完成整个序列的运算。首先新建LSTM网络层,例如:
import tensorflow as tf
from tensorflow.keras import layers
x = tf.random.normal([2, 80, 100])
xt = x[:, 0, :] # 得到一个时间戳的输入
# 创建一层LSTM层,内存向量长度为64
layer = layers.LSTM(64)
# 序列通过LSTM层,默认返回最后一个时间戳的输出h
out = layer(x)
print(out)
运行结果如下所示:
tf.Tensor(
[[ 3.95199209e-02 -1.25364646e-01 2.94331431e-01 -1.14304878e-01
-7.70256370e-02 1.86939016e-01 1.09123811e-01 -1.70704648e-01
-6.79743290e-02 -5.36438338e-02 -1.46275446e-01 -1.71275020e-01
-2.45340884e-01 -2.98208650e-02 -3.66737805e-02 -1.68236911e-01
-1.98541462e-01 -2.34785363e-01 -2.27337614e-01 -7.58730620e-02
-3.21510695e-02 1.60149783e-01 2.14893147e-02 -3.82845886e-02
1.45989448e-01 1.58292249e-01 7.58244544e-02 3.96522313e-01
-3.98388773e-01 6.20449428e-03 2.42546663e-01 -2.60287732e-01
1.64717659e-01 -1.24321856e-01 6.81726709e-02 -2.74276286e-02
-1.25487283e-01 -1.63889065e-01 2.92698648e-02 1.23230442e-01
1.19882874e-01 1.05866484e-01 1.84267804e-01 -1.10565513e-01
-2.10590884e-01 1.21882586e-02 8.51344876e-03 1.94185689e-01
5.85921034e-02 4.94556054e-02 2.12238431e-01 -5.09564996e-01
-3.28018159e-01 -1.78813078e-02 -2.99931765e-01 -2.11669654e-01
5.21721058e-02 6.09412789e-02 2.07320631e-01 8.31936523e-02
-2.09688351e-01 -1.87082580e-04 -1.42284408e-01 -6.08472563e-02]
[-2.67648339e-01 3.77257615e-01 -3.67589504e-01 4.40428883e-01
-2.93904513e-01 -1.54660210e-01 -1.52553350e-01 1.66283622e-01
7.23016635e-02 -5.92147820e-02 1.15891494e-01 -3.28541547e-01
6.15414232e-02 -1.05647065e-01 3.12377661e-01 3.11838657e-01
-6.04985952e-02 1.62625775e-01 1.24153405e-01 -3.06676924e-01
-4.26172577e-02 1.09481081e-01 1.06710918e-01 3.09303731e-01
2.91411936e-01 4.09397393e-01 1.39634579e-01 3.17679256e-01
9.52589884e-03 8.11550394e-02 -2.33638585e-01 -1.65312901e-01
-1.46116212e-01 -5.98626174e-02 -7.03526586e-02 2.66706377e-01
1.24516964e-01 7.42204934e-02 -1.81595474e-01 -1.76071122e-01
2.25278378e-01 3.99431646e-01 -7.47859553e-02 -1.43579349e-01
-1.82527732e-02 -2.58512676e-01 -1.80932239e-01 1.79400817e-02
-8.25663432e-02 3.42711993e-02 1.83398575e-02 2.79354781e-01
-3.88162695e-02 3.08591742e-02 3.06969106e-01 -2.46708021e-01
1.51661798e-01 -3.37232023e-01 -1.74425498e-01 -1.59780517e-01
1.56771298e-02 -1.12598367e-01 -3.81872207e-01 -4.74575490e-01]], shape=(2, 64), dtype=float32)
经过LSTM层前向传播后,默认只会返回最后一个时间戳的输出,如果需要返回每个时间戳上面的输出,需要设置return_sequences=True
标志。例如:
x = tf.random.normal([2, 80, 100])
xt = x[:, 0, :] # 得到一个时间戳的输入
# 创建一层LSTM层时,设置返回每个时间戳上的输出
layer = layers.LSTM(64, return_sequences=True)
# 前向计算,每个时间戳上的输出自动进行了concat,拼成一个张量
out = layer(x)
print(out)
运行结果如下所示:
tf.Tensor(
[[[-0.1717405 -0.08930297 0.06967232 ... 0.19747356 0.03752691
0.03109626]
[-0.20078596 -0.21639289 0.02701378 ... 0.18499883 -0.19541116
-0.02364185]
[-0.19680767 0.27665883 0.01163184 ... 0.06681632 0.13440426
-0.25573087]
...
[-0.25789988 0.03805498 -0.20123775 ... 0.12448035 -0.08662152
-0.08548131]
[-0.08856305 -0.02335903 0.16257113 ... 0.01324088 0.06416073
0.03601749]
[-0.27191302 0.11080401 0.39651835 ... 0.00711586 0.0080503
-0.07360849]]
[[-0.05843922 0.09174234 0.1052835 ... 0.18925342 0.05077972
-0.03099187]
[ 0.06802684 0.31600884 0.12164785 ... 0.13088599 0.00310348
-0.02805575]
[-0.15079215 0.25950116 0.03152351 ... -0.0091973 -0.02098833
0.06655266]
...
[ 0.3425578 0.1668575 -0.15170851 ... 0.3681929 -0.04511892
-0.14054409]
[ 0.13910554 -0.02730367 -0.13297835 ... 0.32460186 -0.00657132
-0.02183609]
[ 0.17034388 -0.00563782 -0.01790518 ... 0.15904477 -0.02315805
-0.04788382]]], shape=(2, 80, 64), dtype=float32)
此时返回的out包含了所有时间戳上面的状态输出,它的shape是
[
2
,
80
,
64
]
[2,80,64]
[2,80,64],其中的80代表了80个时间戳。
对于多层神经网络,可以通过Sequential容器包裹多层LSTM层,并设置所有非末层网络return_sequences=True
,这时因为非末层的LSTM层需要上一层在所有时间戳的输出作为输入。例如:
x = tf.random.normal([2, 80, 100])
xt = x[:, 0, :] # 得到一个时间戳的输入
# 和CNN网络一样,LSTM也可以简单地层层堆叠
net = keras.Sequential([
layers.LSTM(64, return_sequences=True), # 非末层需要返回所有时间戳输出
layers.LSTM(64)
])
# 一次通过网络模型,即可得到最末层、最后一个时间戳的输出
out = net(x)
print(out)
运行结果如下所示:
tf.Tensor(
[[ 0.00495507 0.12214414 -0.00160048 -0.02036193 0.0297945 0.09253313
-0.00228602 0.02402182 0.08100939 0.1723831 -0.05328411 -0.05486928
0.0479381 0.10523208 0.02489639 0.04800454 0.08452084 -0.01752731
-0.06765405 -0.08004566 -0.01247124 0.04532595 0.01210674 0.07003723
-0.04933583 -0.13701366 0.13280801 0.10668018 -0.01475972 -0.0987322
-0.00833084 -0.06607937 0.12339266 -0.03158065 -0.2647863 0.05131374
-0.0744001 -0.02532054 0.00673654 -0.13447735 0.0298725 -0.0697123
-0.10419335 0.02265518 -0.10876047 0.00242805 -0.07227369 -0.02027368
-0.03992875 0.18170744 0.0242668 -0.07278123 -0.04417139 -0.07520955
-0.01291837 0.03547625 0.09313367 0.05298894 0.03429606 0.09236071
-0.09987302 -0.03809222 -0.01680355 0.01949273]
[ 0.01952335 -0.04352669 0.13416202 -0.01402709 0.15527318 -0.11830826
0.02282524 0.002989 0.11030477 0.09961697 -0.11718795 -0.03884947
0.16997313 0.03634973 0.00700031 0.07569288 -0.02721572 -0.06567027
0.00273777 -0.16852596 0.0967081 -0.12281822 0.0451036 0.08031995
0.00726448 -0.15765832 0.03930927 0.06194751 0.09278093 0.11886592
0.01556191 0.07962804 0.10291664 0.17924826 0.00687183 -0.02100213
0.02635089 0.02811113 -0.01453197 -0.03181836 0.10666267 0.05558187
0.09320072 0.22136413 0.00443991 -0.00453833 0.1726003 0.04420204
-0.16399541 0.06623511 0.1213371 0.05969626 -0.15529574 -0.07246035
-0.00795945 0.09380019 0.04675768 -0.00230794 -0.05555401 -0.04707287
0.03893764 -0.03569769 0.10812693 -0.05769645]], shape=(2, 64), dtype=float32)