Getting the source code
https://www.tensorflow.org/addons/api_docs/python/tfa/seq2seq/BasicDecoder
Source code on GitHub:
# Imports used by the class below (tensorflow_addons/seq2seq/basic_decoder.py).
import collections
from typing import Optional

import tensorflow as tf
from typeguard import typechecked

from tensorflow_addons.seq2seq import decoder
from tensorflow_addons.seq2seq import sampler as sampler_py
from tensorflow_addons.utils import keras_utils


class BasicDecoder(decoder.BaseDecoder):
    """Basic sampling decoder for training and inference.

    The `tfa.seq2seq.Sampler` instance passed as argument is responsible to
    sample from the output distribution and produce the input for the next
    decoding step. The decoding loop is implemented by the decoder in its
    `__call__` method.

    Example using `tfa.seq2seq.TrainingSampler` for training:

    >>> batch_size = 4
    >>> max_time = 7
    >>> hidden_size = 32
    >>> embedding_size = 48
    >>> input_vocab_size = 128
    >>> output_vocab_size = 64
    >>>
    >>> embedding_layer = tf.keras.layers.Embedding(input_vocab_size, embedding_size)
    >>> decoder_cell = tf.keras.layers.LSTMCell(hidden_size)
    >>> sampler = tfa.seq2seq.TrainingSampler()
    >>> output_layer = tf.keras.layers.Dense(output_vocab_size)
    >>>
    >>> decoder = tfa.seq2seq.BasicDecoder(decoder_cell, sampler, output_layer)
    >>>
    >>> input_ids = tf.random.uniform(
    ...     [batch_size, max_time], maxval=input_vocab_size, dtype=tf.int64)
    >>> input_lengths = tf.fill([batch_size], max_time)
    >>> input_tensors = embedding_layer(input_ids)
    >>> initial_state = decoder_cell.get_initial_state(input_tensors)
    >>>
    >>> output, state, lengths = decoder(
    ...     input_tensors, sequence_length=input_lengths, initial_state=initial_state)
    >>>
    >>> logits = output.rnn_output
    >>> logits.shape
    TensorShape([4, 7, 64])

    Example using `tfa.seq2seq.GreedyEmbeddingSampler` for inference:

    >>> sampler = tfa.seq2seq.GreedyEmbeddingSampler(embedding_layer)
    >>> decoder = tfa.seq2seq.BasicDecoder(
    ...     decoder_cell, sampler, output_layer, maximum_iterations=10)
    >>>
    >>> initial_state = decoder_cell.get_initial_state(batch_size=batch_size, dtype=tf.float32)
    >>> start_tokens = tf.fill([batch_size], 1)
    >>> end_token = 2
    >>>
    >>> output, state, lengths = decoder(
    ...     None, start_tokens=start_tokens, end_token=end_token, initial_state=initial_state)
    >>>
    >>> output.sample_id.shape
    TensorShape([4, 10])
    """

    @typechecked
    def __init__(
        self,
        cell: tf.keras.layers.Layer,
        sampler: sampler_py.Sampler,
        output_layer: Optional[tf.keras.layers.Layer] = None,
        **kwargs,
    ):
        """Initialize BasicDecoder.

        Args:
          cell: A layer that implements the `tf.keras.layers.AbstractRNNCell`
            interface.
          sampler: A `tfa.seq2seq.Sampler` instance.
          output_layer: (Optional) An instance of `tf.keras.layers.Layer`, e.g.,
            `tf.keras.layers.Dense`. Optional layer to apply to the RNN output
            prior to storing the result or sampling.
          **kwargs: Other keyword arguments of `tfa.seq2seq.BaseDecoder`.
        """
        keras_utils.assert_like_rnncell("cell", cell)
        self.cell = cell
        self.sampler = sampler
        self.output_layer = output_layer
        super().__init__(**kwargs)

    def initialize(self, inputs, initial_state=None, **kwargs):
        """Initialize the decoder."""
        # Assume the dtype of the cell is the output_size structure
        # containing the input_state's first component's dtype.
        self._cell_dtype = tf.nest.flatten(initial_state)[0].dtype
        return self.sampler.initialize(inputs, **kwargs) + (initial_state,)

    @property
    def batch_size(self):
        return self.sampler.batch_size

    def _rnn_output_size(self):
        size = tf.TensorShape(self.cell.output_size)
        if self.output_layer is None:
            return size
        else:
            # To use layer's compute_output_shape, we need to convert the
            # RNNCell's output_size entries into shapes with an unknown
            # batch size. We then pass this through the layer's
            # compute_output_shape and read off all but the first (batch)
            # dimensions to get the output size of the rnn with the layer
            # applied to the top.
            output_shape_with_unknown_batch = tf.nest.map_structure(
                lambda s: tf.TensorShape([None]).concatenate(s), size
            )
            layer_output_shape = self.output_layer.compute_output_shape(
                output_shape_with_unknown_batch
            )
            return tf.nest.map_structure(lambda s: s[1:], layer_output_shape)

    @property
    def output_size(self):
        # Return the cell output and the id
        return BasicDecoderOutput(
            rnn_output=self._rnn_output_size(), sample_id=self.sampler.sample_ids_shape
        )

    @property
    def output_dtype(self):
        # Assume the dtype of the cell is the output_size structure
        # containing the input_state's first component's dtype.
        # Return that structure and the sample_ids_dtype from the helper.
        dtype = self._cell_dtype
        return BasicDecoderOutput(
            tf.nest.map_structure(lambda _: dtype, self._rnn_output_size()),
            self.sampler.sample_ids_dtype,
        )

    def step(self, time, inputs, state, training=None):
        """Perform a decoding step.

        Args:
          time: scalar `int32` tensor.
          inputs: A (structure of) input tensors.
          state: A (structure of) state tensors and TensorArrays.
          training: Python boolean.

        Returns:
          `(outputs, next_state, next_inputs, finished)`.
        """
        cell_outputs, cell_state = self.cell(inputs, state, training=training)
        cell_state = tf.nest.pack_sequence_as(state, tf.nest.flatten(cell_state))
        if self.output_layer is not None:
            cell_outputs = self.output_layer(cell_outputs)
        sample_ids = self.sampler.sample(
            time=time, outputs=cell_outputs, state=cell_state
        )
        (finished, next_inputs, next_state) = self.sampler.next_inputs(
            time=time, outputs=cell_outputs, state=cell_state, sample_ids=sample_ids
        )
        outputs = BasicDecoderOutput(cell_outputs, sample_ids)
        return (outputs, next_state, next_inputs, finished)


class BasicDecoderOutput(
    collections.namedtuple("BasicDecoderOutput", ("rnn_output", "sample_id"))
):
    """Outputs of a `tfa.seq2seq.BasicDecoder` step.

    Attributes:
      rnn_output: The output for this step. If the `output_layer` argument
        of `tfa.seq2seq.BasicDecoder` was set, it is the output of this layer,
        otherwise it is the output of the RNN cell.
      sample_id: The token IDs sampled for this step, as returned by the
        `sampler` instance passed to `tfa.seq2seq.BasicDecoder`.
    """

    pass
In summary: an RNN cell (here an LSTMCell) and an output_layer are passed in. At each time step, `BasicDecoder.step` runs the cell on the current input and the previous cell state, applies `output_layer` to the cell output, lets the sampler pick `sample_ids` and build the inputs for the next step, and packs the result into a `BasicDecoderOutput`. The decoding loop implemented in `__call__` repeats this step after step, so the per-step outputs accumulate into a sequence, much like an ordinary RNN unrolled over time. The sketch below drives this loop by hand.
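To make that loop concrete, here is a minimal sketch (not part of the GitHub source above) that calls `initialize()` and `step()` manually instead of invoking the decoder once; the layer sizes mirror the docstring example and are only illustrative.

import tensorflow as tf
import tensorflow_addons as tfa

batch_size, max_time, hidden_size = 4, 7, 32
vocab_size, embedding_size = 64, 48

embedding_layer = tf.keras.layers.Embedding(vocab_size, embedding_size)
decoder_cell = tf.keras.layers.LSTMCell(hidden_size)
output_layer = tf.keras.layers.Dense(vocab_size)
sampler = tfa.seq2seq.TrainingSampler()
decoder = tfa.seq2seq.BasicDecoder(decoder_cell, sampler, output_layer)

# Teacher-forcing inputs: [batch, time, embedding].
input_ids = tf.random.uniform([batch_size, max_time], maxval=vocab_size, dtype=tf.int64)
inputs = embedding_layer(input_ids)
lengths = tf.fill([batch_size], max_time)
state = decoder_cell.get_initial_state(inputs)

# `initialize` delegates to the sampler and appends the initial cell state.
finished, next_inputs, state = decoder.initialize(
    inputs, initial_state=state, sequence_length=lengths)

step_logits = []
for t in range(max_time):
    # Each `step` runs the cell, applies output_layer, samples ids and
    # asks the sampler for the next step's inputs.
    outputs, state, next_inputs, finished = decoder.step(
        tf.constant(t), next_inputs, state)
    step_logits.append(outputs.rnn_output)

# Here all sequences have the same length, so a fixed-count loop is enough;
# the real decoding loop instead stops based on the `finished` flags.
logits = tf.stack(step_logits, axis=1)  # [batch, max_time, vocab_size]
print(logits.shape)  # (4, 7, 64)

In practice you simply call `decoder(...)` as in the docstring examples; the manual loop above only shows, in simplified form, what the loop inside `__call__` is doing.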