Tensorflow2中如何处理RNN的变长输入问题

在tensorflow2中,废弃了tf.nn.dynamic_rnn函数,在tensorflow2的文档中可以看到

Warning: THIS FUNCTION IS DEPRECATED. It will be removed in a future version. Instructions for updating: Please use keras.layers.RNN(cell), which is equivalent to this API

但是实际上keras.layers.RNN(cell)这个API中并没有sequence_length这个参数.在stackoverflow和github上搜索之后,得到如下解决办法,即要实现RNN的变长输入,需要在调用keras的RNN层时传入一个mask.测试代码如下:

"""
环境:ubuntu18.04, python3.7, tensorflow2.0
"""

import numpy as np
import tensorflow as tf


# 初始化两条输入数据,一条长度为3,一条长度为4
# 把它们均padding到长度为6
seq = np.array([[1, 2, 1, 0, 0, 0], [0, 1, 2, 1, 0, 0]], dtype=np.int32)
seq_len = np.array([3, 4], dtype=np.int32)

# 定义嵌入层
emb_weights = np.array([[0, 0], [1, 1], [2, 2]])
emb = tf.keras.layers.Embedding(3, 2, weights=[emb_weights])

# 定义一个简单的RNN层
rnn = tf.keras.layers.SimpleRNN(
    1,
    activation=None,
    use_bias=False,
    kernel_initializer='ones',   # RNN输入端到节点之间的参数初始化为1
    recurrent_initializer='ones',  # RNN节点到节点之间的参数初始化为1
    return_sequences=True)

# 定义双向RNN层
bi_rnn = tf.keras.layers.Bidirectional(rnn)


# 建立keras模型
x = tf.keras.layers.Input([None])
mask = tf.keras.layers.Input([None])  # 根据序列长度定义一个mask
emb_x = emb(x)
rnn_out = rnn(emb_x, mask=mask)  # 在调用RNN过程中加入mask
bi_rnn_out = bi_rnn(emb_x, mask=mask)

m1 = tf.keras.Model(inputs=[x, mask], outputs=rnn_out)
m2 = tf.keras.Model(inputs=[x, mask], outputs=bi_rnn_out)

# 打印模型结果
print(m1.predict([seq, tf.sequence_mask(seq_len, 6)]).squeeze())
print(m2.predict([seq, tf.sequence_mask(seq_len, 6)]).squeeze())

结果输出如下:

# RNN的输出
[[2. 6. 8. 8. 8. 8.]
 [0. 2. 6. 8. 8. 8.]]

# 双向RNN的输出
[[[2. 8.]
  [6. 6.]
  [8. 2.]
  [0. 0.]
  [0. 0.]
  [0. 0.]]

 [[0. 8.]
  [2. 8.]
  [6. 6.]
  [8. 2.]
  [0. 0.]
  [0. 0.]]]

可以得到一些结论:

1. 在单向RNN的情况下,RNN在计算到序列结束处时,会在RNN的剩余节点上简单的复制序列结束处节点的状态值,所以可以看到RNN的输出在padding处为8,因为8是序列结束处RNN节点的输出.

2. 对于双向RNN,则会在序列结束之后,输出0,所以我们可以看到第一条数据的后三个节点输出都是0,第二条数据的后两个节点输出都是0.

  • 2
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值