Reading the source code of tfa.seq2seq.BasicDecoder

Where to find the source

https://www.tensorflow.org/addons/api_docs/python/tfa/seq2seq/BasicDecoder

The source on GitHub

# Module-level imports from tensorflow_addons/seq2seq/basic_decoder.py,
# abridged to what the excerpt below actually uses.
import collections
from typing import Optional

import tensorflow as tf
from typeguard import typechecked

from tensorflow_addons.seq2seq import decoder
from tensorflow_addons.seq2seq import sampler as sampler_py
from tensorflow_addons.utils import keras_utils


class BasicDecoder(decoder.BaseDecoder):
    """Basic sampling decoder for training and inference.
    The `tfa.seq2seq.Sampler` instance passed as argument is responsible to sample from
    the output distribution and produce the input for the next decoding step. The decoding
    loop is implemented by the decoder in its `__call__` method.
    Example using `tfa.seq2seq.TrainingSampler` for training:
    >>> batch_size = 4
    >>> max_time = 7
    >>> hidden_size = 32
    >>> embedding_size = 48
    >>> input_vocab_size = 128
    >>> output_vocab_size = 64
    >>>
    >>> embedding_layer = tf.keras.layers.Embedding(input_vocab_size, embedding_size)
    >>> decoder_cell = tf.keras.layers.LSTMCell(hidden_size)
    >>> sampler = tfa.seq2seq.TrainingSampler()
    >>> output_layer = tf.keras.layers.Dense(output_vocab_size)
    >>>
    >>> decoder = tfa.seq2seq.BasicDecoder(decoder_cell, sampler, output_layer)
    >>>
    >>> input_ids = tf.random.uniform(
    ...     [batch_size, max_time], maxval=input_vocab_size, dtype=tf.int64)
    >>> input_lengths = tf.fill([batch_size], max_time)
    >>> input_tensors = embedding_layer(input_ids)
    >>> initial_state = decoder_cell.get_initial_state(input_tensors)
    >>>
    >>> output, state, lengths = decoder(
    ...     input_tensors, sequence_length=input_lengths, initial_state=initial_state)
    >>>
    >>> logits = output.rnn_output
    >>> logits.shape
    TensorShape([4, 7, 64])
    Example using `tfa.seq2seq.GreedyEmbeddingSampler` for inference:
    >>> sampler = tfa.seq2seq.GreedyEmbeddingSampler(embedding_layer)
    >>> decoder = tfa.seq2seq.BasicDecoder(
    ...     decoder_cell, sampler, output_layer, maximum_iterations=10)
    >>>
    >>> initial_state = decoder_cell.get_initial_state(batch_size=batch_size, dtype=tf.float32)
    >>> start_tokens = tf.fill([batch_size], 1)
    >>> end_token = 2
    >>>
    >>> output, state, lengths = decoder(
    ...     None, start_tokens=start_tokens, end_token=end_token, initial_state=initial_state)
    >>>
    >>> output.sample_id.shape
    TensorShape([4, 10])
    """

    @typechecked
    def __init__(
        self,
        cell: tf.keras.layers.Layer,
        sampler: sampler_py.Sampler,
        output_layer: Optional[tf.keras.layers.Layer] = None,
        **kwargs,
    ):
        """Initialize BasicDecoder.
        Args:
          cell: A layer that implements the `tf.keras.layers.AbstractRNNCell`
            interface.
          sampler: A `tfa.seq2seq.Sampler` instance.
          output_layer: (Optional) An instance of `tf.keras.layers.Layer`, i.e.,
            `tf.keras.layers.Dense`. Optional layer to apply to the RNN output
             prior to storing the result or sampling.
          **kwargs: Other keyword arguments of `tfa.seq2seq.BaseDecoder`.
        """
        keras_utils.assert_like_rnncell("cell", cell)
        self.cell = cell
        self.sampler = sampler
        self.output_layer = output_layer
        super().__init__(**kwargs)

    def initialize(self, inputs, initial_state=None, **kwargs):
        """Initialize the decoder."""
        # Assume the dtype of the cell is the output_size structure
        # containing the input_state's first component's dtype.
        self._cell_dtype = tf.nest.flatten(initial_state)[0].dtype
        return self.sampler.initialize(inputs, **kwargs) + (initial_state,)

    @property
    def batch_size(self):
        return self.sampler.batch_size

    def _rnn_output_size(self):
        size = tf.TensorShape(self.cell.output_size)
        if self.output_layer is None:
            return size
        else:
            # To use layer's compute_output_shape, we need to convert the
            # RNNCell's output_size entries into shapes with an unknown
            # batch size.  We then pass this through the layer's
            # compute_output_shape and read off all but the first (batch)
            # dimensions to get the output size of the rnn with the layer
            # applied to the top.
            output_shape_with_unknown_batch = tf.nest.map_structure(
                lambda s: tf.TensorShape([None]).concatenate(s), size
            )
            layer_output_shape = self.output_layer.compute_output_shape(
                output_shape_with_unknown_batch
            )
            return tf.nest.map_structure(lambda s: s[1:], layer_output_shape)

    @property
    def output_size(self):
        # Return the cell output and the id
        return BasicDecoderOutput(
            rnn_output=self._rnn_output_size(), sample_id=self.sampler.sample_ids_shape
        )

    @property
    def output_dtype(self):
        # Assume the dtype of the cell is the output_size structure
        # containing the input_state's first component's dtype.
        # Return that structure and the sample_ids_dtype from the helper.
        dtype = self._cell_dtype
        return BasicDecoderOutput(
            tf.nest.map_structure(lambda _: dtype, self._rnn_output_size()),
            self.sampler.sample_ids_dtype,
        )

    def step(self, time, inputs, state, training=None):
        """Perform a decoding step.
        Args:
          time: scalar `int32` tensor.
          inputs: A (structure of) input tensors.
          state: A (structure of) state tensors and TensorArrays.
          training: Python boolean.
        Returns:
          `(outputs, next_state, next_inputs, finished)`.
        """
        cell_outputs, cell_state = self.cell(inputs, state, training=training)
        cell_state = tf.nest.pack_sequence_as(state, tf.nest.flatten(cell_state))
        if self.output_layer is not None:
            cell_outputs = self.output_layer(cell_outputs)
        sample_ids = self.sampler.sample(
            time=time, outputs=cell_outputs, state=cell_state
        )
        (finished, next_inputs, next_state) = self.sampler.next_inputs(
            time=time, outputs=cell_outputs, state=cell_state, sample_ids=sample_ids
        )
        outputs = BasicDecoderOutput(cell_outputs, sample_ids)
        return (outputs, next_state, next_inputs, finished)


class BasicDecoderOutput(
    collections.namedtuple("BasicDecoderOutput", ("rnn_output", "sample_id"))
):
    """Outputs of a `tfa.seq2seq.BasicDecoder` step.
    Attributes:
      rnn_output: The output for this step. If the `output_layer` argument
         of `tfa.seq2seq.BasicDecoder` was set, it is the output of this layer, otherwise it
         is the output of the RNN cell.
      sample_id: The token IDs sampled for this step, as returned by the
        `sampler` instance passed to `tfa.seq2seq.BasicDecoder`.
    """

    pass

In summary: you pass in an RNN cell (an LSTMCell in the examples) together with an optional output_layer. At each time step, BasicDecoder.step feeds the previous step's cell state and the current input to the cell, applies output_layer to the cell output, and lets the sampler pick the sample ids and the next inputs, so the output sequence is built up step by step, much like an ordinary RNN unrolled over time.
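
The docstring examples above only show the high-level decoder(...) call. As a complement, here is a minimal sketch (not from the original source; all sizes and names are made up for illustration) that drives the same decoder by hand through initialize() and step(), which is essentially the loop that the decoder's __call__ / dynamic_decode runs internally:

    import tensorflow as tf
    import tensorflow_addons as tfa

    batch_size, max_time = 4, 7
    hidden_size, embedding_size, vocab_size = 32, 48, 64

    embedding_layer = tf.keras.layers.Embedding(vocab_size, embedding_size)
    cell = tf.keras.layers.LSTMCell(hidden_size)
    sampler = tfa.seq2seq.TrainingSampler()
    output_layer = tf.keras.layers.Dense(vocab_size)
    decoder = tfa.seq2seq.BasicDecoder(cell, sampler, output_layer)

    # Teacher-forcing inputs: embedded token ids for every time step.
    input_ids = tf.random.uniform(
        [batch_size, max_time], maxval=vocab_size, dtype=tf.int64)
    inputs = embedding_layer(input_ids)
    sequence_length = tf.fill([batch_size], max_time)
    initial_state = cell.get_initial_state(inputs)

    # initialize() delegates to sampler.initialize() and appends initial_state.
    finished, next_inputs, state = decoder.initialize(
        inputs, initial_state=initial_state, sequence_length=sequence_length)

    # Hand-rolled version of the decoding loop: each step consumes the previous
    # state and the current input, applies output_layer, and lets the sampler
    # choose sample_ids and the inputs for the next step.
    for t in range(max_time):
        outputs, state, next_inputs, finished = decoder.step(
            tf.constant(t), next_inputs, state)
        # outputs.rnn_output: [batch_size, vocab_size] logits for step t
        # outputs.sample_id:  [batch_size] token ids chosen by the sampler
        if tf.reduce_all(finished):
            break
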
