【自然语言处理】tf.contrib.seq2seq.TrainingHelper源码解析

前言

本文衔接tf.contrib.seq2seq.dynamic_decode源码分析以及tf.contrib.seq2seq.BasicDecoder源码解析。除了TrainingHelper后面还会介绍到GreedyEmbeddingHelper。
TrainingHelper代码

正文

首先先要明确的是在训练阶段,我们需要给于解码器句子,并得到相对应的输出随后进行训练。

class TrainingHelper(Helper):
  """A helper for use during training.  Only reads inputs.
  Returned sample_ids are the argmax of the RNN output logits.
  """

  def __init__(self, inputs, sequence_length, time_major=False, name=None):
    """Initializer.
    Args:
      inputs: A (structure of) input tensors.
      sequence_length: An int32 vector tensor.
      time_major: Python bool.  Whether the tensors in `inputs` are time major.
        If `False` (default), they are assumed to be batch major.
      name: Name scope for any created operations.
    Raises:
      ValueError: if `sequence_length` is not a 1D tensor.
    """
    with ops.name_scope(name, "TrainingHelper", [inputs, sequence_length]):
      inputs = ops.convert_to_tensor(inputs, name="inputs")
      self._inputs = inputs
      if not time_major:
        inputs = nest.map_structure(_transpose_batch_time, inputs)

      self._input_tas = nest.map_structure(_unstack_ta, inputs)
      self._sequence_length = ops.convert_to_tensor(
          sequence_length, name="sequence_length")
      if self._sequence_length.get_shape().ndims != 1:
        raise ValueError(
            "Expected sequence_length to be a vector, but received shape: %s" %
            self._sequence_length.get_shape())

      self._zero_inputs = nest.map_structure(
          lambda inp: array_ops.zeros_like(inp[0, :]), inputs)

      self._batch_size = array_ops.size(sequence_length)

所以TrainingHelper接收的参数主要有一个大小为[batch_size, seqlen, embed_size]的输入inputs;以及每个句子的真实长度sequence_length,是一个[batch_size]的向量;time_major为真则把seqlen作为第一维。注意下sequence_length是一个batch_size大小的数组,指明了每个句子的真实长度(因为有些长度是padding的)。

  def initialize(self, name=None):
    with ops.name_scope(name, "TrainingHelperInitialize"):
      finished = math_ops.equal(0, self._sequence_length)
      all_finished = math_ops.reduce_all(finished)
      next_inputs = control_flow_ops.cond(
          all_finished, lambda: self._zero_inputs,
          lambda: nest.map_structure(lambda inp: inp.read(0), self._input_tas))
      return (finished, next_inputs)

这里主要是初始化,给于外界第一个输入数据。

  def next_inputs(self, time, outputs, state, name=None, **unused_kwargs):
    """next_inputs_fn for TrainingHelper."""
    with ops.name_scope(name, "TrainingHelperNextInputs",
                        [time, outputs, state]):
      next_time = time + 1
      finished = (next_time >= self._sequence_length)
      all_finished = math_ops.reduce_all(finished)
      def read_from_ta(inp):
        return inp.read(next_time)
      next_inputs = control_flow_ops.cond(
          all_finished, lambda: self._zero_inputs,
          lambda: nest.map_structure(read_from_ta, self._input_tas))
      return (finished, next_inputs, state)

在TrainingHelper的next_inputs中,我们每次读取都是inputs中time+1的数据,并且返回给之前的数据。注意这里有个finished,这里意思就是判断当前time是否大于seqlen,如果大于说明这个输出应该为0向量。

  def sample(self, time, outputs, name=None, **unused_kwargs):
    with ops.name_scope(name, "TrainingHelperSample", [time, outputs]):
      sample_ids = math_ops.cast(
          math_ops.argmax(outputs, axis=-1), dtypes.int32)
      return sample_ids

在这里也实现了一个sample函数,主要是用来采样的,取输出概率最大的词作为当前的输出词。其实在TrainingHelper比较关心next_inputs,而在推理阶段,我们更关注这个sample函数。

总结

可以看到,在Seq2seq提供了各种各样的Helper,在这个Helper中基本都提供了一个next_inputs和sample函数,但是在训练阶段我们更关注于next_inputs这个函数,因为我们只是想要输出然后用于后面的训练。

  • 3
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值