RNN在深度学习中占据重要地位,我们常常调用tensorflow的包就可以完成RNN的构建与训练,但通用的RNN并不总是能满足我们的需求,若要改动,必先知其细。
也许你会说,我自己用for循环写个rnn的实现不就好了嘛,当然可以啊。但内置的函数一般都比for循环快,用 while_loop 的好处是速度快效率高,因为它是一个tf的内置运算,会构建入运算图的,循环运行的时候不会再与python作交互。
下面我们根据源码对RNN的实现一探究竟。在探究之前,先来说一下什么叫动态RNN,我们都知道RNN全名是循环神经网络,循环嘛,自然是动态的,通过循环的方式动态生成一个个的token,直至整句话生成完毕时停止。这里用到的知识点是tf.while_loop(),如下
#举例:循环变量a,b,c;f(.),g(.),h(.)是函数
tf.while_loop(
condition,
body,
loop_vars=(init_a,init_b,init_c)
)
def condition(unuesd_a,b,unuesd_c):
# 即使a,c变量用不到,也要写在condition函数的参数中
return b>1 # 返回bool类型的值
def body(a,b,c):
next_a=f(a)
next_b=g(b)
next_c=h(c)
return next_a,next_b,next_c
这个函数的功能是:当condition函数return的结果为True时,进入循环体body()进行计算,并返回更新后的变量的值;如果condition函数return的结果为False,循环结束。
目录
第一板斧 负责决定当前时间步的输出(sample函数)和下一时间步的输入(next_inputs函数)。
第三板斧负责模拟RNN在每个时间步的情况,并在合适的时刻(比如遇到eos或者达到指定的最大长度)停止。
1 tensorflow 版本
import tensorflow as tf
tf.__version__ # tensorflow版本为1.12.0
2 动态RNN实现“三板斧”
如果定制自己需要的动态RNN,只需要修改三板斧中的对应函数,即可将自己的想法融入tf框架中,无需自己从0实现一个动态RNN,原因有二,一是方便,二是自己从0实现的不一定比tf的写的好emm
第一板斧 负责决定当前时间步的输出(sample函数)和下一时间步的输入(next_inputs函数)。
helper = tf.contrib.seq2seq.TrainingHelper(inputs=input_embed,…) # input_embed是rnn输入字符的embedding
上述函数在文件helper.py中,是专用于训练时候的helper,除此之外,helper.py中还有适用于inference时候的helper,一起来看看源码(下面为源码的重要部分截取,不是完整的helper.py文件,下同),关键是sample函数和next_inputs函数的实现。
# helper.py中所有的class,除了用于训练的TrainingHelper,
# 还有一些用于推断时候的helper,甚至可以自定义,即CustomHelper。
# 对于每个helper,关键在于sample函数和next_inputs函数的实现。
__all__ = [
"Helper",
"TrainingHelper",
"GreedyEmbeddingHelper",
"SampleEmbeddingHelper",
"CustomHelper",
"ScheduledEmbeddingTrainingHelper",
"ScheduledOutputTrainingHelper",
"InferenceHelper",
]
#训练阶段 以TrainingHelper为例进行分析
class TrainingHelper(Helper):
def __init__(self, inputs, sequence_length, time_major=False, name=None):
initial部分的源码不进行粘贴
def sample(self, time, outputs, name=None, **unused_kwargs):
# 采样得到当前时间步的输出token
with ops.name_scope(name, "TrainingHelperSample", [time, outputs]):
sample_ids = math_ops.cast(
math_ops.argmax(outputs, axis=-1), dtypes.int32) # 取概率最大的token作为输出
return sample_ids
def next_inputs(self, time, outputs, state, name=None, **unused_kwargs):
"""next_inputs_fn for TrainingHelper."""
with ops.name_scope(name, "TrainingHelperNextInputs",
[time, outputs, state]):
next_time = time + 1
finished = (next_time >= self._sequence_length)
all_finished = math_ops.reduce_all(finished)
def read_from_ta(inp):
return inp.read(next_time)
# 若rnn未finished,则取当前时间步输出真值作为下一步的输入。因为训练阶段是有标签的
next_inputs = control_flow_ops.cond(
all_finished, lambda: self._zero_inputs,
lambda: nest.map_structure(read_from_ta, self._input_tas))
return (finished, next_inputs, state)
# 推断阶段的helper以GreedyEmbeddingHelper为例进行分析
class GreedyEmbeddingHelper(Helper):
def sample(self, time, outputs, state, name=None):
"""sample for GreedyEmbeddingHelper."""
del time, state # unused by sample_fn
# Outputs are logits, use argmax to get the most probable id
if not isinstance(outputs, ops.Tensor):
raise TypeError("Expected outputs to be a single Tensor, got: %s" %
type(outputs))
sample_ids = math_ops.argmax(outputs, axis=-1, output_type=dtypes.int32)
return sample_ids
def next_inputs(self, time, outputs, state, sample_ids, name=None):
"""next_inputs_fn for GreedyEmbeddingHelper."""
del time, outputs # unused by next_inputs_fn
finished = math_ops.equal(sample_ids, self._end_token)
all_finished = math_ops.reduce_all(finished)
# 因为是推断阶段,所以把当前时间步的输出的预测值作为下一步的输入。
# sample_ids是token id,所以用_embedding_fn函数得到其embedding后再作为next_inputs
next_inputs = control_flow_ops.cond(
all_finished,
# If we're finished, the next_inputs value doesn't matter
lambda: self._start_inputs,
lambda: self._embedding_fn(sample_ids))
return (finished, next_inputs, state)
第二板斧 负责执行一个时间步(step函数),
调用cell得到该时间步的输出概率,调用helper得到该时间步的输出token id和下一步的输入token的embedding。
decoder = tf.contrib.seq2seq.BasicDecoder(cell=rnn_cell, helper=helper,…)
上述函数在basic_decoder.py中,BasicDecoder类继承于Decoder类(Decoder类在decoder.py文件中,和dynamic_decode函数在一个文件中),实现了Decoder类中的step函数。
其他的Decoder比如BeamSearchDecoder也继承于Decoder类,实现了Decoder类中的step函数。
所以,如果想自己实现一个decoder的话,继承Decoder类并实现step函数即可。
def step(self, time, inputs, state, name=None):
"""Perform a decoding step.
Args:
time: scalar `int32` tensor.
inputs: A (structure of) input tensors.
state: A (structure of) state tensors and TensorArrays.
name: Name scope for any created operations.
Returns:
`(outputs, next_state, next_inputs, finished)`.
"""
with ops.name_scope(name, "BasicDecoderStep", (time, inputs, state)):
cell_outputs, cell_state = self._cell(inputs, state)
if self._output_layer is not None:
cell_outputs = self._output_layer(cell_outputs)
sample_ids = self._helper.sample(
time=time, outputs=cell_outputs, state=cell_state)
(finished, next_inputs, next_state) = self._helper.next_inputs(
time=time,
outputs=cell_outputs,
state=cell_state,
sample_ids=sample_ids)
outputs = BasicDecoderOutput(cell_outputs, sample_ids)
return (outputs, next_state, next_inputs, finished)
可以看出,step函数中调用了cell来得到当前时间步的输出,这里的cell是rnn_cell,定义了RNN的结构,所以了解cell的输入与输出是什么很重要,这样才能正确调用。如果想了解常用rnn_cell的结构,可以阅读Tensorflow RNN结构 解读
第三板斧负责模拟RNN在每个时间步的情况,并在合适的时刻(比如遇到eos或者达到指定的最大长度)停止。
final_outputs, final_state, final_sequence_lengths = tf.contrib.seq2seq.dynamic_decode(decoder=decoder,…)
上述函数位于decoder.py文件中,dynamic_decode是一个loop(循环)来得到全部时间步的情况,每个时间步都调用decoder。
"""
condition是判断是否停止的条件,body中会调用decoder.step()来得到相关信息。
loop_vars是在循环中不断变化更新的变量,这些变量需要输入到body函数中,
在body函数中计算更新并return,以作为下一个循环body函数的输入。
这里res的内容其实就是body函数返回的内容,也就是loop_vars的值。
"""
def condition(unused_time, unused_outputs_ta, unused_state, unused_inputs,
finished, unused_sequence_lengths,):
return math_ops.logical_not(math_ops.reduce_all(finished))
res = control_flow_ops.while_loop(
condition,
body,
loop_vars=(
initial_time,
initial_outputs_ta,
initial_state,
initial_inputs,
initial_finished,
initial_sequence_lengths,
),
parallel_iterations=parallel_iterations,
maximum_iterations=maximum_iterations,
swap_memory=swap_memory)
所以,如果想定制自己需要的动态RNN,要想清楚loop_vars有哪些,然后写到loop_vars中哦~
我们来分析下停止条件condition()函数,finished的shape是(batch_size,),math_ops.reduce_all()是Computes the "logical and" of elements across dimensions of a tensor,math_ops.logical_not()是逻辑非。若要循环停止,则math_ops.logical_not()为False,即math_ops.reduce_all()为True,即finished中每个元素都为True。我们知道,循环结束的标志是句子解码结束,每个finished中第i个元素为True,代表该batch中第i个句子已经解码结束,所以,结论是,只有当batch中所有句子都解码结束,才会停止循环,即一个 batch 中的句子长度不相同时,得到的 dynamic_length 应该是某个 batch 中最长的一句的长度。