8.2.8 Translation
Now that the model is trained, the next step is to perform the full text => text translation. To do this the model needs to invert the text => token IDs mapping provided by output_text_processor, and it also needs to know the IDs of the special tokens. This is all set up in the constructor of a new class. Overall, the process is similar to the training loop, except that the decoder input at each time step is a sample drawn from the decoder's last prediction.
(1) Write the Translator class to implement translation. The code is as follows:
class Translator(tf.Module):
    def __init__(self, encoder, decoder, input_text_processor,
                 output_text_processor):
        self.encoder = encoder
        self.decoder = decoder
        self.input_text_processor = input_text_processor
        self.output_text_processor = output_text_processor

        self.output_token_string_from_index = (
            tf.keras.layers.experimental.preprocessing.StringLookup(
                vocabulary=output_text_processor.get_vocabulary(),
                mask_token='',
                invert=True))

        # The output should never generate padding, unknown, or start tokens
        index_from_string = tf.keras.layers.experimental.preprocessing.StringLookup(
            vocabulary=output_text_processor.get_vocabulary(), mask_token='')
        token_mask_ids = index_from_string(['', '[UNK]', '[START]']).numpy()
        token_mask = np.zeros([index_from_string.vocabulary_size()], dtype=bool)
        token_mask[np.array(token_mask_ids)] = True
        self.token_mask = token_mask

        self.start_token = index_from_string(tf.constant('[START]'))
        self.end_token = index_from_string(tf.constant('[END]'))

translator = Translator(
    encoder=train_translator.encoder,
    decoder=train_translator.decoder,
    input_text_processor=input_text_processor,
    output_text_processor=output_text_processor,
)
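The two StringLookup tables built in the constructor are just a forward and an inverse vocabulary map. As a minimal pure-Python sketch of that behavior (using a made-up toy vocabulary, not the one returned by output_text_processor.get_vocabulary()):

```python
# Hypothetical toy vocabulary: index 0 is the mask/padding token '',
# index 1 is '[UNK]', the rest are ordinary words.
vocab = ['', '[UNK]', '[START]', '[END]', 'hello', 'world']
table = {word: i for i, word in enumerate(vocab)}

def index_from_string(tokens):
    """Strings -> IDs; out-of-vocabulary words map to '[UNK]' (index 1)."""
    return [table.get(t, 1) for t in tokens]

def string_from_index(ids):
    """IDs -> strings: the invert=True direction."""
    return [vocab[i] for i in ids]

# token_mask marks the IDs the output should never produce.
token_mask = [w in ('', '[UNK]', '[START]') for w in vocab]

print(index_from_string(['hello', 'xyz']))   # 'xyz' is unknown -> 1
print(string_from_index([2, 4, 5, 3]))
```

This is only a conceptual sketch; the real layers also handle tensors, batching, and serialization.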
(2) Convert token IDs to text
The first method to implement is tokens_to_text, which converts token IDs into human-readable text. The code is as follows:
def tokens_to_text(self, result_tokens):
    shape_checker = ShapeChecker()
    shape_checker(result_tokens, ('batch', 't'))

    result_text_tokens = self.output_token_string_from_index(result_tokens)
    shape_checker(result_text_tokens, ('batch', 't'))

    result_text = tf.strings.reduce_join(result_text_tokens,
                                         axis=1, separator=' ')
    shape_checker(result_text, ('batch',))

    result_text = tf.strings.strip(result_text)
    shape_checker(result_text, ('batch',))
    return result_text

Translator.tokens_to_text = tokens_to_text
Then feed in some random token IDs and see what it generates:
example_output_tokens = tf.random.uniform(
    shape=[5, 2], minval=0, dtype=tf.int64,
    maxval=output_text_processor.vocabulary_size())
translator.tokens_to_text(example_output_tokens).numpy()

array([b'divorce nodded', b'lid discovery', b'exhibition slam',
       b'unknown jackson', b'harmful excited'], dtype=object)
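The three TF ops in tokens_to_text (lookup, reduce_join, strip) can be mimicked in plain NumPy. This toy sketch, with a hypothetical six-word vocabulary, shows why the strip step is needed once padding tokens get joined in:

```python
import numpy as np

# Hypothetical toy vocabulary; index 0 is the padding token ''.
vocab = np.array(['', '[UNK]', '[START]', '[END]', 'hello', 'world'])
result_tokens = np.array([[2, 4, 5, 3, 0, 0]])  # [START] hello world [END] pad pad

words = vocab[result_tokens]               # lookup: IDs -> strings, shape ('batch', 't')
joined = [' '.join(row) for row in words]  # reduce_join along the time axis
stripped = [s.strip() for s in joined]     # strip the blanks left by padding

print(stripped)  # ['[START] hello world [END]']
```

Without the strip, the joined string would carry one trailing space per padding token.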
(3) Sample from the decoder's predictions
Write the sample() function, which takes the decoder's logit output and samples token IDs from that distribution. The code is as follows:
def sample(self, logits, temperature):
    shape_checker = ShapeChecker()
    # 't' is usually 1 here
    shape_checker(logits, ('batch', 't', 'vocab'))
    shape_checker(self.token_mask, ('vocab',))

    token_mask = self.token_mask[tf.newaxis, tf.newaxis, :]
    shape_checker(token_mask, ('batch', 't', 'vocab'), broadcast=True)

    # Set the logits of the masked tokens to -inf so they are never sampled
    logits = tf.where(self.token_mask, -np.inf, logits)

    if temperature == 0.0:
        new_tokens = tf.argmax(logits, axis=-1)
    else:
        logits = tf.squeeze(logits, axis=1)
        new_tokens = tf.random.categorical(logits/temperature,
                                           num_samples=1)

    shape_checker(new_tokens, ('batch', 't'))
    return new_tokens

Translator.sample = sample
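The masking-plus-temperature logic in sample() can be illustrated without TensorFlow. A NumPy sketch of the same idea, with made-up logits over a six-token vocabulary and the first three IDs banned, behaves like this:

```python
import numpy as np

# Toy (batch=1, t=1, vocab=6) logits; mask bans '', [UNK], [START] (IDs 0-2).
logits = np.array([[[2.0, 1.0, 0.5, 0.0, 3.0, 2.5]]])
token_mask = np.array([True, True, True, False, False, False])

# Banned tokens get -inf, so neither argmax nor sampling can ever pick them.
masked = np.where(token_mask, -np.inf, logits)

# temperature == 0.0: greedy decoding via argmax.
greedy = masked.argmax(axis=-1)
print(greedy)  # [[4]] -- token 4 has the highest unmasked logit

# temperature > 0.0: sample from softmax(logits / temperature);
# exp(-inf) == 0, so banned tokens have probability exactly zero.
temperature = 1.0
p = np.exp(masked[0, 0] / temperature)
p /= p.sum()
sampled = np.random.choice(len(p), p=p)
```

Higher temperatures flatten the distribution (more random output); temperature 0 always picks the single most likely unmasked token.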
(4) Implement the translation loop
Write the translate_unrolled() function to implement the text-to-text translation loop. It collects the results into a Python list, then uses tf.concat to join them into a tensor. This implementation statically unrolls the graph for max_length iterations.
def translate_unrolled(self,
                       input_text, *,
                       max_length=50,
                       return_attention=True,
                       temperature=1.0):
    batch_size = tf.shape(input_text)[0]
    input_tokens = self.input_text_processor(input_text)
    enc_output, enc_state = self.encoder(input_tokens)

    dec_state = enc_state
    new_tokens = tf.fill([batch_size, 1], self.start_token)

    result_tokens = []
    attention = []
    done = tf.zeros([batch_size, 1], dtype=tf.bool)

    for _ in range(max_length):
        dec_input = DecoderInput(new_tokens=new_tokens,
                                 enc_output=enc_output,
                                 mask=(input_tokens != 0))
        dec_result, dec_state = self.decoder(dec_input, state=dec_state)
        attention.append(dec_result.attention_weights)

        new_tokens = self.sample(dec_result.logits, temperature)

        # Once a sequence generates its end_token it is done
        done = done | (new_tokens == self.end_token)
        # Finished sequences only produce padding from here on
        new_tokens = tf.where(done, tf.constant(0, dtype=tf.int64), new_tokens)

        result_tokens.append(new_tokens)

        if tf.executing_eagerly() and tf.reduce_all(done):
            break

    # Join the per-step token IDs into one tensor and convert it to text
    result_tokens = tf.concat(result_tokens, axis=-1)
    result_text = self.tokens_to_text(result_tokens)

    if return_attention:
        attention_stack = tf.concat(attention, axis=1)
        return {'text': result_text, 'attention': attention_stack}
    else:
        return {'text': result_text}

Translator.translate = translate_unrolled
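The done bookkeeping in the loop (latch done, then zero out further tokens) can be checked in isolation. In this NumPy sketch with a hypothetical end_token of 3 and a batch of two sequences, everything a sequence generates from its end token onward (the end token itself included) becomes padding:

```python
import numpy as np

end_token = 3  # hypothetical [END] id
done = np.zeros([2, 1], dtype=bool)

# Three decoding steps for a batch of two sequences.
steps = [np.array([[4], [3]]),   # sequence 1 emits [END] here
         np.array([[3], [5]]),   # sequence 0 emits [END] here
         np.array([[6], [7]])]

result_tokens = []
for new_tokens in steps:
    done = done | (new_tokens == end_token)     # latch: once done, stays done
    new_tokens = np.where(done, 0, new_tokens)  # finished rows emit only padding
    result_tokens.append(new_tokens)

print(np.concatenate(result_tokens, axis=-1))
# [[4 0 0]
#  [0 0 0]]
```

Because done is a latch (bitwise OR), a sequence can never "un-finish", which is what lets the eager-mode early exit check tf.reduce_all(done).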
After execution, the attention visualization of the translation result is shown in Figure 8-10.
Figure 8-10 Attention visualization of the translation result