TensorFlow2学习-RNN

simple RNN 单元的建立

cell = layers.SimpleRNNCell(3)
cell.build(input_shape=(None, 4))

In [10]: cell.trainable_variables
[<tf.Variable 'kernel:0' shape=(4, 3) dtype=float32, numpy=
 array([[-0.68201065, -0.29878205,  0.6812657 ],
        [-0.6166027 , -0.13658696, -0.6273672 ],
        [-0.5876779 ,  0.03141975, -0.27967763],
        [ 0.65881634,  0.06509167, -0.6725838 ]], dtype=float32)>,
 <tf.Variable 'recurrent_kernel:0' shape=(3, 3) dtype=float32, numpy=
 array([[ 0.5435171 ,  0.7652735 ,  0.34488472],
        [-0.5323441 ,  0.63193214, -0.5632686 ],
        [ 0.6489983 , -0.12254879, -0.75085485]], dtype=float32)>,
 <tf.Variable 'bias:0' shape=(3,) dtype=float32, numpy=array([0., 0., 0.], dtype=float32)>]

x = tf.random.normal([4, 80, 100])
xt = x[:,0,:]
cell = layers.SimpleRNNCell(64)
out, h1 = cell(xt, h0) # 前向计算
In [16]: out.shape, h1[0].shape
Out[16]: (TensorShape([4, 64]), TensorShape([4, 64]))
In [17]: print(id(out), id(h1[0]))
Out[17]:1720317129832 1720317129832

h = h0
# 在序列长度的维度解开输入,得到xt:[b,f]
for xt in tf.unstack(x, axis=1):
    out, h = cell(xt, h) # 前向计算
# 最终输出可以聚合每个时间戳上的输出,也可以只取最后时间戳的输出
out = out


#省去中间过程,直接使用
In [24]: layer = keras.layers.SimpleRNN(64)
    ...: x = tf.random.normal([4, 80, 100])
    ...: out = layer(x)
    ...: out.shape
Out[24]: TensorShape([4, 64])

多层 RNN 单元的建立

x = tf.random.normal([4,80,100])
xt = x[:,0,:] # 取第一个时间戳的输入x0
#构建2个Cell,先cell0,后cell1
cell0 = layers.SimpleRNNCell(64)
cell1 = layers.SimpleRNNCell(64)
h0 = [tf.zeros([4,64])] # cell0的初始状态向量
h1 = [tf.zeros([4,64])] # cell1的初始状态向量

out0, h0 = cell0(xt, h0)
out1, h1 = cell1(out0, h1)

#%%
for xt in tf.unstack(x, axis=1):
    # xtw作为输入,输出为out0
    out0, h0 = cell0(xt, h0)
    # 上一个cell的输出out0作为本cell的输入
    out1, h1 = cell1(out0, h1)
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值