tf.keras.layers.RNN
类在 TensorFlow 中是实现循环神经网络(RNN)的一种灵活且强大的方式。这个类允许在神经网络模型中构建和操作 RNN 层,提供多种选项以满足特定需求。以下是对 tf.keras.layers.RNN
类的主要特性、功能和用法的详细说明:
概览
- 基类 :继承自 TensorFlow 的
Layer
和Module
类,是 TensorFlow 高级 Keras API 的一部分。 - 功能 :主要用于处理数据序列(如时间序列、文本等),捕捉数据中的时间依赖性。
构造函数参数
-
cell :RNN 层的核心。可以是 RNN 单元的实例或 RNN 单元实例的列表。每个单元必须具有:
call(input_at_t, states_at_t)
方法,返回(output_at_t, states_at_t_plus_1)
。单元格的调用方法也可以使用可选参数常量,请参阅下文 "关于传递外部常量的注意事项 "部分。state_size
属性可以是 TensorShape 或 TensorShape 的元组/列表,以表示高维度状态。output_size
属性。它可以是一个整数,也可以是一个 TensorShape,用来表示输出的形状。出于向后兼容的原因,如果单元格中没有该属性,则将根据 state_size 的第一个元素推断出其值。get_initial_state(inputs=None, batch_size=None, dtype=None)
方法用于创建一个张量,如果用户没有通过其他方式指定初始状态,该张量将作为初始状态输入 call()。返回的初始状态应该具有 [batch_size, cell.state_size] 的形状。inputs 是 RNN 层的输入张量,它应该包含作为 shape[0] 的批量大小和 dtype。需要注意的是,shape[0] 可能会在图构建过程中被取消。batch_size 是一个标量张量,表示输入的批次大小。
-
return_sequences :如果为
True
,层输出完整的输出序列;否则,仅返回最后一个输出。 -
return_state :如果为
True
,除了输出外,还返回最后的状态。 -
go_backwards :如果设置为
True
,则反向处理输入序列。 -
stateful :如果为
True
,在批次间保持状态。 -
unroll :如果为
True
,网络将展开;否则,使用符号循环。 -
time_major :确定输入和输出张量的形状格式。
-
zero_output_for_mask :控制是否为掩码的时间步输出零。
调用参数
- inputs :输入张量。
- mask :指示是否掩码某个时间步的二进制张量。
- training :指示层在训练模式还是推理模式下的行为。
- initial_state :单元的初始状态张量。
- constants :在每个时间步传递给单元的常量张量。
输入和输出形状
- 输入形状 :形状为
[batch_size, timesteps, ...]
或在time_major
为 True 时为[timesteps, batch_size, ...]
的 N 维张量。 - 输出形状 :基于
return_sequences
和return_state
变化。可以是[batch_size, timesteps, output_size]
或[batch_size, output_size]
。
掩码
- 支持对具有变化时间步数的输入数据进行掩码。
RNN 的状态性
- 要使用状态性,请设置
stateful=True
并为模型指定固定的批次大小。
重置状态
- 使用
.reset_states()
方法重置 RNN 层的状态。
初始状态规定
- 可以符号性地或数值性地指定初始状态。
向 RNN 传递外部常量
- 可以使用
constants
关键字向 RNN 传递外部常量。
示例用法
from keras.layers import RNN
from keras import backend
class MinimalRNNCell(keras.layers.Layer):
# 实现一个最小的 RNN 单元
cell = MinimalRNNCell(32)
x = keras.Input((None, 5))
layer = RNN(cell)
y = layer(x)
# 堆叠 RNN 示例
cells = [MinimalRNNCell(32), MinimalRNNCell(64)]
layer = RNN(cells)
y = layer(x)
属性和方法
- states :存储 RNN 层的状态。
- reset_states(states=None) :重置有状态 RNN 层的记录状态。
错误处理
- 如果 RNN 层不是有状态的,则引发
AttributeError
。 - 如果 RNN 层的批次大小未知,或者输入的 numpy 数组与 RNN 层的状态大小或数据类型不兼容,则引发
ValueError
。