前言
本系列主要主要是记录下Tensorflow在RNN实现这一块的相关代码,不做详细解释,主要是翻译加笔记。
RNNCell
在Tensorflow中,定义了一个RNNCell的抽象类,具体的所有不同类型的RNN Cell都是基于这个类的,所以就首先讲一下这个,下面是基本的代码:
class RNNCell(object):
def __call__(self, inputs, state, scope=None):
raise NotImplementedError("Abstract method")
@property
def state_size(self):
raise NotImplementedError("Abstract method")
@property
def output_size(self):
raise NotImplementedError("Abstract method")
def zero_state(self, batch_size, dtype):
state_size = self.state_size
if nest.is_sequence(state_size):
state_size_flat = nest.flatten(state_size)
zeros_flat = [
array_ops.zeros(
array_ops.pack(_state_size_with_prefix(s, prefix=[batch_size])),
dtype=dtype)
for s