7.2.6 The Decoder
In this project, the decoder's job is to generate predictions for the next output token.
(1) Write the Decoder class and set its initialization options; the initializer creates all of the necessary layers. The code for the Decoder class is as follows:
class Decoder(tf.keras.layers.Layer):
    def __init__(self, output_vocab_size, embedding_dim, dec_units):
        super(Decoder, self).__init__()
        self.dec_units = dec_units
        self.output_vocab_size = output_vocab_size
        self.embedding_dim = embedding_dim

        # Step 1. The embedding layer converts token IDs to vectors
        self.embedding = tf.keras.layers.Embedding(self.output_vocab_size,
                                                   embedding_dim)

        # Step 2. The RNN keeps track of what has been generated so far
        self.gru = tf.keras.layers.GRU(self.dec_units,
                                       return_sequences=True,
                                       return_state=True,
                                       recurrent_initializer='glorot_uniform')

        # Step 3. The RNN output is the query for the attention layer
        self.attention = BahdanauAttention(self.dec_units)

        # Step 4. This layer converts `ct` to `at`
        self.Wc = tf.keras.layers.Dense(dec_units, activation=tf.math.tanh,
                                        use_bias=False)

        # Step 5. This fully connected layer produces the logits for each output token
        self.fc = tf.keras.layers.Dense(self.output_vocab_size)
The decoder class implements the following steps:
① The decoder receives the complete encoder output.
② It uses an RNN to keep track of what it has generated so far.
③ It uses its RNN output as the query for attention over the encoder output, producing the context vector.
④ It combines the RNN output and the context vector to produce the "attention vector", at = tanh(Wc@[ct; ht]).
⑤ It generates logit predictions for the next token based on the "attention vector".
(2) The call method for this layer takes and returns multiple tensors, so organize them into simple container classes:
import typing
from typing import Any, Tuple  # if not already imported in an earlier section

class DecoderInput(typing.NamedTuple):
    new_tokens: Any
    enc_output: Any
    mask: Any

class DecoderOutput(typing.NamedTuple):
    logits: Any
    attention_weights: Any
Here is the concrete implementation of the call() method:
def call(self,
         inputs: DecoderInput,
         state=None) -> Tuple[DecoderOutput, tf.Tensor]:
    shape_checker = ShapeChecker()
    shape_checker(inputs.new_tokens, ('batch', 't'))
    shape_checker(inputs.enc_output, ('batch', 's', 'enc_units'))
    shape_checker(inputs.mask, ('batch', 's'))

    if state is not None:
        shape_checker(state, ('batch', 'dec_units'))

    # Step 1. Look up the embeddings
    vectors = self.embedding(inputs.new_tokens)
    shape_checker(vectors, ('batch', 't', 'embedding_dim'))

    # Step 2. Process one step with the RNN
    rnn_output, state = self.gru(vectors, initial_state=state)
    shape_checker(rnn_output, ('batch', 't', 'dec_units'))
    shape_checker(state, ('batch', 'dec_units'))

    # Step 3. Use the RNN output as the query for the attention over the encoder output
    context_vector, attention_weights = self.attention(
        query=rnn_output, value=inputs.enc_output, mask=inputs.mask)
    shape_checker(context_vector, ('batch', 't', 'dec_units'))
    shape_checker(attention_weights, ('batch', 't', 's'))

    # Step 4. Join the context_vector and rnn_output
    #     [ct; ht] shape: (batch, t, value_units + query_units)
    context_and_rnn_output = tf.concat([context_vector, rnn_output], axis=-1)

    # Step 4 (continued). `at = tanh(Wc@[ct; ht])`
    attention_vector = self.Wc(context_and_rnn_output)
    shape_checker(attention_vector, ('batch', 't', 'dec_units'))

    # Step 5. Generate the logit predictions
    logits = self.fc(attention_vector)
    shape_checker(logits, ('batch', 't', 'output_vocab_size'))

    return DecoderOutput(logits, attention_weights), state

Decoder.call = call
In this example the encoder processes its entire input sequence with a single call to its RNN. This implementation of the decoder could do the same for efficient training, but this example runs the decoder in a loop, for the following reasons:
- Flexibility: writing the loop gives you direct control over the training procedure.
- Clarity: it is possible to use masking tricks and the layers.RNN or tfa.seq2seq APIs to pack all of this into a single call, but writing it as a loop can be clearer.
(3) Start using the decoder; the code is as follows:
decoder = Decoder(output_text_processor.vocabulary_size(),
                  embedding_dim, units)
The decoder takes 4 inputs (a single-step call is sketched after this list):
- new_tokens: the last token generated. Initialize the decoder with the "[START]" token.
- enc_output: generated by the Encoder.
- mask: a boolean tensor indicating which positions are not padding (tokens != 0).
- state: the previous state output by the decoder (the internal state of the decoder's RNN). Pass None to zero-initialize it. The original paper initializes it from the encoder's final RNN state.
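With the decoder constructed, you can run a single decoding step by hand to verify the output shapes. This is a minimal sketch, assuming example_target_batch is the sample batch used later in the training test, and that example_tokens, example_enc_output, and example_enc_state were produced by running the encoder on a sample batch in the previous section (these names are illustrative):
# Convert the target batch to token IDs and build a column of "[START]" tokens
example_output_tokens = output_text_processor(example_target_batch)
start_index = output_text_processor.get_vocabulary().index('[START]')
first_token = tf.constant([[start_index]] * example_output_tokens.shape[0])

# Run one step of the decoder
dec_result, dec_state = decoder(
    inputs=DecoderInput(new_tokens=first_token,
                        enc_output=example_enc_output,
                        mask=(example_tokens != 0)),
    state=example_enc_state)

print(f'logits shape: {dec_result.logits.shape}')   # (batch, 1, output_vocab_size)
print(f'state shape:  {dec_state.shape}')           # (batch, dec_units)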
7.2.7 Training
Now that you have all of the model components, it is time to start training the model.
(1) Define the loss function. The code is as follows:
class MaskedLoss(tf.keras.losses.Loss):
    def __init__(self):
        self.name = 'masked_loss'
        self.loss = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction='none')

    def __call__(self, y_true, y_pred):
        shape_checker = ShapeChecker()
        shape_checker(y_true, ('batch', 't'))
        shape_checker(y_pred, ('batch', 't', 'logits'))

        # Calculate the loss for each item in the batch
        loss = self.loss(y_true, y_pred)
        shape_checker(loss, ('batch', 't'))

        # Mask off the losses on padding positions
        mask = tf.cast(y_true != 0, tf.float32)
        shape_checker(mask, ('batch', 't'))
        loss *= mask

        # Return the total
        return tf.reduce_sum(loss)
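As a quick sanity check (toy tensors, not from the dataset), the mask ensures that padded positions, where y_true is 0, contribute nothing to the summed loss:
loss_fn = MaskedLoss()
y_true = tf.constant([[4, 7, 0, 0]])    # the last two positions are padding
y_pred = tf.zeros((1, 4, 10))           # uniform logits over a toy 10-token vocabulary
print(loss_fn(y_true, y_pred).numpy())  # about 2 * ln(10); only the 2 real tokens count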
(2) Implement the training step
Start with a model class; the whole training process will be implemented as the train_step method on that model. The train_step() method is a wrapper around the _train_step implementation, which comes later. This wrapper includes a switch to turn tf.function compilation on and off, which makes debugging easier.
class TrainTranslator(tf.keras.Model):
    def __init__(self, embedding_dim, units,
                 input_text_processor,
                 output_text_processor,
                 use_tf_function=True):
        super().__init__()
        # Build the encoder and decoder
        encoder = Encoder(input_text_processor.vocabulary_size(),
                          embedding_dim, units)
        decoder = Decoder(output_text_processor.vocabulary_size(),
                          embedding_dim, units)

        self.encoder = encoder
        self.decoder = decoder
        self.input_text_processor = input_text_processor
        self.output_text_processor = output_text_processor
        self.use_tf_function = use_tf_function
        self.shape_checker = ShapeChecker()

    def train_step(self, inputs):
        self.shape_checker = ShapeChecker()
        if self.use_tf_function:
            return self._tf_train_step(inputs)
        else:
            return self._train_step(inputs)
(3) Write the _preprocess() method, which receives a batch of input_text and target_text from the tf.data.Dataset and converts these raw text inputs to token IDs and masks:
def _preprocess(self, input_text, target_text):
    self.shape_checker(input_text, ('batch',))
    self.shape_checker(target_text, ('batch',))

    # Convert the text to token IDs
    input_tokens = self.input_text_processor(input_text)
    target_tokens = self.output_text_processor(target_text)
    self.shape_checker(input_tokens, ('batch', 's'))
    self.shape_checker(target_tokens, ('batch', 't'))

    # Convert the IDs to masks
    input_mask = input_tokens != 0
    self.shape_checker(input_mask, ('batch', 's'))

    target_mask = target_tokens != 0
    self.shape_checker(target_mask, ('batch', 't'))

    return input_tokens, input_mask, target_tokens, target_mask

TrainTranslator._preprocess = _preprocess
(4) Write the _train_step() method, which handles the remaining steps apart from actually running the decoder. The code is as follows:
def _train_step(self, inputs):
    input_text, target_text = inputs

    (input_tokens, input_mask,
     target_tokens, target_mask) = self._preprocess(input_text, target_text)

    max_target_length = tf.shape(target_tokens)[1]

    with tf.GradientTape() as tape:
        # Encode the input
        enc_output, enc_state = self.encoder(input_tokens)
        self.shape_checker(enc_output, ('batch', 's', 'enc_units'))
        self.shape_checker(enc_state, ('batch', 'enc_units'))

        # Initialize the decoder's state from the encoder's final state
        dec_state = enc_state
        loss = tf.constant(0.0)

        for t in tf.range(max_target_length - 1):
            # Pass in two tokens at a time: the current input token
            # and the next token, which is the label for this step
            new_tokens = target_tokens[:, t:t+2]
            step_loss, dec_state = self._loop_step(new_tokens, input_mask,
                                                   enc_output, dec_state)
            loss = loss + step_loss

        # Average the loss over all non-padding tokens
        average_loss = loss / tf.reduce_sum(tf.cast(target_mask, tf.float32))

    # Apply an optimization step
    variables = self.trainable_variables
    gradients = tape.gradient(average_loss, variables)
    self.optimizer.apply_gradients(zip(gradients, variables))

    # Return a dict mapping metric names to current values
    return {'batch_loss': average_loss}

TrainTranslator._train_step = _train_step
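Note how the loop above feeds the decoder a two-token slice of the target at each step: the first token is the teacher-forced input and the second is the label for that step. A toy illustration with hypothetical token IDs:
# Hypothetical target sequence [START], "he", "is", [END] -> IDs [1, 7, 9, 2]
toy_target = tf.constant([[1, 7, 9, 2]])
for t in range(toy_target.shape[1] - 1):
    print(toy_target[:, t:t+2].numpy())  # [[1 7]], then [[7 9]], then [[9 2]]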
(5) Write the _loop_step() method, which runs the decoder and computes the incremental loss and the new decoder state (dec_state):
def _loop_step(self, new_tokens, input_mask, enc_output, dec_state):
    input_token, target_token = new_tokens[:, 0:1], new_tokens[:, 1:2]

    # Run the decoder one step
    decoder_input = DecoderInput(new_tokens=input_token,
                                 enc_output=enc_output,
                                 mask=input_mask)

    dec_result, dec_state = self.decoder(decoder_input, state=dec_state)
    self.shape_checker(dec_result.logits, ('batch', 't1', 'logits'))
    self.shape_checker(dec_result.attention_weights, ('batch', 't1', 's'))
    self.shape_checker(dec_state, ('batch', 'dec_units'))

    # `self.loss` returns the total loss for the non-padded tokens
    y = target_token
    y_pred = dec_result.logits
    step_loss = self.loss(y, y_pred)

    return step_loss, dec_state

TrainTranslator._loop_step = _loop_step
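The train_step() wrapper defined earlier calls self._tf_train_step() when use_tf_function is True, but that method has not been shown yet. Following the original tutorial, it simply wraps _train_step in tf.function with a string input signature:
@tf.function(input_signature=[[tf.TensorSpec(dtype=tf.string, shape=[None]),
                               tf.TensorSpec(dtype=tf.string, shape=[None])]])
def _tf_train_step(self, inputs):
    return self._train_step(inputs)

TrainTranslator._tf_train_step = _tf_train_step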
(6) Test the training step
Build a TrainTranslator and configure it for training with the Model.compile method:
translator = TrainTranslator(
    embedding_dim, units,
    input_text_processor=input_text_processor,
    output_text_processor=output_text_processor,
    use_tf_function=False)

translator.compile(
    optimizer=tf.optimizers.Adam(),
    loss=MaskedLoss(),
)
Then test train_step. For a text model like this, the loss should start near the log of the vocabulary size, since an untrained model assigns roughly uniform probability to every output token:
np.log(output_text_processor.vocabulary_size())
for n in range(10):
    print(translator.train_step([example_input_batch, example_target_batch]))
print()
Running this on the author's machine produces the following output:
7.517193191416238
{'batch_loss': <tf.Tensor: shape=(), dtype=float32, numpy=7.614782>}
{'batch_loss': <tf.Tensor: shape=(), dtype=float32, numpy=7.5835567>}
{'batch_loss': <tf.Tensor: shape=(), dtype=float32, numpy=7.5252647>}
{'batch_loss': <tf.Tensor: shape=(), dtype=float32, numpy=7.361221>}
{'batch_loss': <tf.Tensor: shape=(), dtype=float32, numpy=6.7776713>}
{'batch_loss': <tf.Tensor: shape=(), dtype=float32, numpy=5.271942>}
{'batch_loss': <tf.Tensor: shape=(), dtype=float32, numpy=4.822084>}
{'batch_loss': <tf.Tensor: shape=(), dtype=float32, numpy=4.702935>}
{'batch_loss': <tf.Tensor: shape=(), dtype=float32, numpy=4.303531>}
{'batch_loss': <tf.Tensor: shape=(), dtype=float32, numpy=4.150844>}
CPU times: user 5.21 s, sys: 0 ns, total: 5.21 s
Wall time: 5.17 s
Finally, write code to plot the loss curve:
losses = []
for n in range(100):
    print('.', end='')
    logs = translator.train_step([example_input_batch, example_target_batch])
    losses.append(logs['batch_loss'].numpy())

print()
plt.plot(losses)
The resulting loss curve is shown in Figure 7-8.
Figure 7-8 The loss curve
(7) Train the model
While there is nothing wrong with the custom training loop written above, implementing the Model.train_step() method lets you run Model.fit and avoid rewriting all of that boilerplate code. This example only trains for a few epochs, so use a callbacks.Callback to collect the history of batch losses for plotting:
class BatchLogs(tf.keras.callbacks.Callback):
    def __init__(self, key):
        self.key = key
        self.logs = []

    def on_train_batch_end(self, n, logs):
        self.logs.append(logs[self.key])

batch_loss = BatchLogs('batch_loss')
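The fit() call below uses train_translator rather than the translator built earlier with use_tf_function=False. A minimal sketch of rebuilding and compiling the model with tf.function compilation enabled (the default), mirroring the earlier construction:
train_translator = TrainTranslator(
    embedding_dim, units,
    input_text_processor=input_text_processor,
    output_text_processor=output_text_processor)

train_translator.compile(
    optimizer=tf.optimizers.Adam(),
    loss=MaskedLoss())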
train_translator.fit(dataset, epochs=3,
                     callbacks=[batch_loss])
Executing this produces the following output:
Epoch 1/3
2023-07-31 11:08:55.431052: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:801] function_optimizer failed: Invalid argument: Input 6 of node StatefulPartitionedCall/gradient_tape/while/while_grad/body/_589/gradient_tape/while/gradients/while/decoder_2/gru_5/PartitionedCall_grad/PartitionedCall was passed variant from StatefulPartitionedCall/gradient_tape/while/while_grad/body/_589/gradient_tape/while/gradients/while/decoder_2/gru_5/PartitionedCall_grad/TensorListPopBack_2:1 incompatible with expected float.
2023-07-31 11:08:55.515851: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:801] shape_optimizer failed: Out of range: src_output = 25, but num_outputs is only 25
2023-07-31 11:08:55.556380: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:801] layout failed: Out of range: src_output = 25, but num_outputs is only 25
2023-07-31 11:08:55.674137: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:801] function_optimizer failed: Invalid argument: Input 6 of node StatefulPartitionedCall/gradient_tape/while/while_grad/body/_589/gradient_tape/while/gradients/while/decoder_2/gru_5/PartitionedCall_grad/PartitionedCall was passed variant from StatefulPartitionedCall/gradient_tape/while/while_grad/body/_589/gradient_tape/while/gradients/while/decoder_2/gru_5/PartitionedCall_grad/TensorListPopBack_2:1 incompatible with expected float.
2023-07-31 11:08:55.729119: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:801] shape_optimizer failed: Out of range: src_output = 25, but num_outputs is only 25
2023-07-31 11:08:55.802715: W tensorflow/core/common_runtime/process_function_library_runtime.cc:841] Ignoring multi-device function optimization failure: Invalid argument: Input 1 of node StatefulPartitionedCall/while/body/_59/while/TensorListPushBack_56 was passed float from StatefulPartitionedCall/while/body/_59/while/decoder_2/gru_5/PartitionedCall:6 incompatible with expected variant.
1859/1859 [==============================] - 353s 187ms/step - batch_loss: 2.0502
Epoch 2/3
1859/1859 [==============================] - 333s 179ms/step - batch_loss: 1.0388
Epoch 3/3
1859/1859 [==============================] - 323s 174ms/step - batch_loss: 0.8104
<keras.callbacks.History at 0x7fc2ccb315d0>
Write the following code to visualize the batch losses:
plt.plot(batch_loss.logs)
plt.ylim([0, 3])
plt.xlabel('Batch #')
plt.ylabel('CE/token')
The resulting plot is shown in Figure 7-9. As the figure shows, the visible jumps in the loss occur mainly at the epoch boundaries.