文章目录
论文地址: https://arxiv.org/pdf/1706.03762.pdf
代码参考:https://wmathor.com/index.php/archives/1455/
备注:该代码在构建 Transformer 模型时,除位置编码 PositionalEncoding 中默认 p=0.1 的 dropout 外,其余各层均不含 dropout 层。
数据预处理
采用了两对德语→英语翻译的句子,每个字的索引通过手动硬编码,降低代码阅读难度。构建编码器的输入、解码器的输入、解码器的输出即真实标签。
import math
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Data
# Special tokens used in the toy corpus:
#   S: start-of-sequence marker (first token of every decoder input)
#   E: end-of-sequence marker (last token of every decoder target)
#   P: padding token used to fill a sequence up to the chosen maximum length
sentences = [
    # [encoder input, decoder input, decoder target (ground-truth labels)]
    ['ich mochte ein bier P', 'S i want a beer .', 'i want a beer . E'],
    ['ich mochte ein cola P', 'S i want a coke .', 'i want a coke . E'],
]

# Source and target vocabularies with hand-assigned indices.
# The padding token must be index 0: the padding mask tests `eq(0)` and the
# loss is created with `ignore_index=0`.
src_vocab = {'P': 0, 'ich': 1, 'mochte': 2, 'ein': 3, 'bier': 4, 'cola': 5}
src_vocab_size = len(src_vocab)  # 6
tgt_vocab = {'P': 0, 'i': 1, 'want': 2, 'a': 3, 'beer': 4, 'coke': 5, 'S': 6, 'E': 7, '.': 8}
tgt_vocab_size = len(tgt_vocab)  # 9

# Index -> word lookup used to print predictions: {0: 'P', 1: 'i', ..., 8: '.'}.
# Built by inverting tgt_vocab so it stays correct even if the dict's insertion
# order ever differs from the stored indices.
idx2word = {index: word for word, index in tgt_vocab.items()}

src_len = 5  # maximum length of an encoder input sequence
tgt_len = 6  # maximum length of a decoder input (= decoder target) sequence
# Build the encoder inputs, the decoder inputs and the decoder targets (ground-truth labels)
def make_data(sentences):
    """Convert raw sentence triples into index tensors.

    Args:
        sentences: iterable of ``[enc_input, dec_input, dec_output]`` entries,
            each a whitespace-separated token string. Only the first three
            items of an entry are used.

    Returns:
        A tuple ``(enc_inputs, dec_inputs, dec_outputs)`` of ``LongTensor``s,
        one row per sentence, e.g. for the toy corpus
        ``[[1, 2, 3, 4, 0], ...]``, ``[[6, 1, 2, 3, 4, 8], ...]`` and
        ``[[1, 2, 3, 4, 8, 7], ...]``.

    Raises:
        KeyError: if a token is missing from ``src_vocab`` / ``tgt_vocab``.
    """
    enc_inputs, dec_inputs, dec_outputs = [], [], []
    # Iterate over the entries directly instead of indexing with range(len(...)).
    for enc_text, dec_in_text, dec_out_text, *_ in sentences:
        enc_inputs.append([src_vocab[token] for token in enc_text.split()])
        dec_inputs.append([tgt_vocab[token] for token in dec_in_text.split()])
        dec_outputs.append([tgt_vocab[token] for token in dec_out_text.split()])
    return torch.LongTensor(enc_inputs), torch.LongTensor(dec_inputs), torch.LongTensor(dec_outputs)
enc_inputs, dec_inputs, dec_outputs = make_data(sentences) # make_data returns three LongTensors
# enc_inputs: [batch_size, src_len] = [2, 5]
# dec_inputs / dec_outputs: [batch_size, tgt_len] = [2, 6]
class MyDataSet(Data.Dataset):
    """Dataset that serves aligned (encoder input, decoder input, decoder target) rows."""

    def __init__(self, enc_inputs, dec_inputs, dec_outputs):
        super().__init__()
        # The three containers are indexed in lockstep by __getitem__.
        self.enc_inputs, self.dec_inputs, self.dec_outputs = enc_inputs, dec_inputs, dec_outputs

    def __len__(self):
        # Number of samples = size of the first dimension of the encoder inputs.
        num_samples = self.enc_inputs.shape[0]
        return num_samples

    def __getitem__(self, idx):
        # Return the idx-th entry of each container as one sample.
        sample = (self.enc_inputs[idx], self.dec_inputs[idx], self.dec_outputs[idx])
        return sample
# There are only two sentence pairs, so batch_size=2 puts the whole dataset into a single batch
loader = Data.DataLoader(MyDataSet(enc_inputs, dec_inputs, dec_outputs), batch_size=2, shuffle=True)
Positional Encoding
每个位置的变化方式如式:
$$
PE_{(pos,2i)}=\sin\left(pos/10000^{2i/d_{model}}\right) \\
PE_{(pos,2i+1)}=\cos\left(pos/10000^{2i/d_{model}}\right)
$$
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a sequence of embeddings.

    Even feature indices use ``sin(pos / 10000^(2i/d_model))`` and odd feature
    indices use ``cos(pos / 10000^(2i/d_model))``. The table is precomputed for
    ``max_len`` positions and stored as a non-trainable buffer.

    Args:
        d_model: size of the embedding (feature) dimension; may be odd.
        dropout: dropout probability applied after adding the encoding.
        max_len: number of positions to precompute.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # [max_len, 1]
        # 1 / 10000^(2i/d_model), computed in log space
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # [max_len, ceil(d_model / 2)]
        pe[:, 0::2] = torch.sin(angles)  # even feature indices
        # With an odd d_model there is one fewer odd index than even index, so
        # drop the surplus frequency instead of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])  # odd feature indices
        # Stored as [max_len, 1, d_model] so it broadcasts over the batch dimension.
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        """Return ``x`` plus the encoding of its positions, followed by dropout.

        Args:
            x: embeddings of shape [seq_len, batch_size, d_model].
        """
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
这里画出图来看一下位置编码:
# Plot feature dimensions 4-7 of the positional encoding for the first 100 positions
import matplotlib.pyplot as plt
plt.figure(figsize=(15, 5))
pe = PositionalEncoding(20, 0) # d_model=20, dropout probability 0
y = pe.forward((torch.zeros(100, 1, 20))) # the input is all zeros, so y contains only the encoding itself
plt.plot(np.arange(100), y[:, 0, 4:8].data.numpy())
plt.legend(["dim %d"%p for p in [4,5,6,7]])
None # NOTE(review): presumably here so the notebook cell does not echo the legend object
模型参数
# Model hyper-parameters, read as module-level globals by the layers below
d_model = 512   # embedding size
d_ff = 2048     # hidden size of the position-wise feed-forward network
d_k = 64        # per-head dimension of K (and Q)
d_v = 64        # per-head dimension of V
n_layers = 6    # number of stacked layers in the encoder and in the decoder
n_heads = 8     # number of heads in multi-head attention
get_attn_pad_mask
def get_attn_pad_mask(seq_q, seq_k):
    """Build a boolean mask that hides the padding positions of the key sequence.

    Only the padding information of ``seq_k`` is used; ``seq_q`` contributes
    just its length, and the two lengths may differ.

    Args:
        seq_q: token indices of the query side, [batch_size, len_q].
        seq_k: token indices of the key side, [batch_size, len_k].

    Returns:
        Bool tensor [batch_size, len_q, len_k], True where the key token is
        the padding index 0.
    """
    _, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    key_is_pad = seq_k.data.eq(0)               # [batch_size, len_k]
    key_is_pad = key_is_pad.unsqueeze(1)        # [batch_size, 1, len_k]
    # Repeat the same key mask for every query position.
    return key_is_pad.expand(batch_size, len_q, len_k)
打印出来看一下效果:
dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs)
# For encoder-decoder attention only the padding of enc_inputs is used; decoder-side padding is ignored
dec_enc_attn_mask
# Output, shape (2, 6, 5):
tensor([[[False, False, False, False, True],
[False, False, False, False, True],
[False, False, False, False, True],
[False, False, False, False, True],
[False, False, False, False, True],
[False, False, False, False, True]],
[[False, False, False, False, True],
[False, False, False, False, True],
[False, False, False, False, True],
[False, False, False, False, True],
[False, False, False, False, True],
[False, False, False, False, True]]])
get_attn_subsequence_mask
# Source of the "Masked" in the decoder's Masked Multi-Head Attention; it lets
# all target positions be processed in parallel without seeing later tokens.
def get_attn_subsequence_mask(seq):
    """Return the look-ahead mask for a batch of decoder inputs.

    Args:
        seq: decoder inputs, [batch_size, tgt_len]; only its shape is used.

    Returns:
        uint8 CPU tensor [batch_size, tgt_len, tgt_len] whose strict upper
        triangle is 1 (positions after the current one) and whose diagonal
        and lower triangle are 0.
    """
    batch_size, tgt_len = seq.size(0), seq.size(1)
    all_ones = torch.ones(batch_size, tgt_len, tgt_len, dtype=torch.uint8)
    # diagonal=1 keeps only the entries strictly above the main diagonal.
    return torch.triu(all_ones, diagonal=1)
打印出来看一下效果:
get_attn_subsequence_mask(dec_inputs)
# Output, shape (2, 6, 6):
tensor([[[0, 1, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 0, 1, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 0]],
[[0, 1, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 0, 1, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)
# Visualise the look-ahead mask of the first sample for a sequence length of 20
import matplotlib.pyplot as plt
a = torch.randn((5, 20)) # random standard-normal values, [batch_size, len]; only the shape is used by the mask function
plt.figure(figsize=(5, 5))
plt.imshow(get_attn_subsequence_mask(a)[0]) # the mask is [batch_size, len, len]; show sample 0
None
Scaled Dot-Product Attention
Scaled Dot-Product Attention 是 Multi-Head Attention 的一部分。
$$
\mathrm{Attention}(Q,K,V)=\mathrm{softmax}\left(\frac{QK^{T}}{\sqrt{d_{k}}}\right)V
$$
class ScaledDotProductAttention(nn.Module):
    """Compute softmax(Q K^T / sqrt(d_k)) V with masked positions excluded."""

    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()

    def forward(self, Q, K, V, attn_mask):
        """Apply scaled dot-product attention.

        Args:
            Q: [batch_size, n_heads, len_q, d_k]
            K: [batch_size, n_heads, len_k, d_k]
            V: [batch_size, n_heads, len_v(=len_k), d_v]
            attn_mask: bool tensor broadcastable to
                [batch_size, n_heads, len_q, len_k]; True marks positions
                that must receive no attention.

        Returns:
            context: [batch_size, n_heads, len_q, d_v], the attention-weighted values.
            attn: [batch_size, n_heads, len_q, len_k], each row sums to 1.
        """
        # Scale by the actual key dimension of the inputs rather than a
        # module-level global, so the layer is correct for any head size.
        scale = math.sqrt(Q.size(-1))
        scores = torch.matmul(Q, K.transpose(-1, -2)) / scale  # [batch_size, n_heads, len_q, len_k]
        # A very large negative score becomes (numerically) zero after softmax.
        scores = scores.masked_fill(attn_mask, -1e9)
        attn = torch.softmax(scores, dim=-1)
        context = torch.matmul(attn, V)  # [.., len_q, len_k] x [.., len_k, d_v] -> [.., len_q, d_v]
        return context, attn
Multi-Head Attention
class MultiHeadAttention(nn.Module):
    """Multi-head attention followed by a residual connection and LayerNorm.

    The layer sizes come from the module-level hyper-parameters ``d_model``,
    ``d_k``, ``d_v`` and ``n_heads``.
    """

    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_K = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_V = nn.Linear(d_model, d_v * n_heads, bias=False)
        self.fc = nn.Linear(n_heads * d_v, d_model, bias=False)
        self.attention = ScaledDotProductAttention()
        # Registered once as a sub-module so that its scale/shift parameters
        # are trained, saved with the model and moved together with it by
        # .to()/.cuda(). Creating a fresh LayerNorm inside forward() would
        # discard those parameters on every call and pin the layer to CUDA.
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, input_Q, input_K, input_V, attn_mask):
        """Attend from ``input_Q`` to ``input_K`` / ``input_V``.

        Args:
            input_Q: [batch_size, len_q, d_model]
            input_K: [batch_size, len_k, d_model]
            input_V: [batch_size, len_v(=len_k), d_model]
            attn_mask: bool tensor [batch_size, len_q, len_k]; True = masked.

        Returns:
            output: [batch_size, len_q, d_model] (same shape as ``input_Q``),
                after attention + residual + LayerNorm.
            attn: [batch_size, n_heads, len_q, len_k] attention weights.
        """
        residual, batch_size = input_Q, input_Q.size(0)
        # (B, S, D) -project-> (B, S, H*W) -split-> (B, S, H, W) -transpose-> (B, H, S, W)
        Q = self.W_Q(input_Q).view(batch_size, -1, n_heads, d_k).transpose(1, 2)  # [batch_size, n_heads, len_q, d_k]
        K = self.W_K(input_K).view(batch_size, -1, n_heads, d_k).transpose(1, 2)  # [batch_size, n_heads, len_k, d_k]
        V = self.W_V(input_V).view(batch_size, -1, n_heads, d_v).transpose(1, 2)  # [batch_size, n_heads, len_k, d_v]
        # Share the same mask across all heads: [batch_size, n_heads, len_q, len_k].
        # expand() creates a view, so no per-head copy is made.
        attn_mask = attn_mask.unsqueeze(1).expand(-1, n_heads, -1, -1)
        context, attn = self.attention(Q, K, V, attn_mask)
        # Concatenate the heads again: [batch_size, len_q, n_heads * d_v]
        context = context.transpose(1, 2).reshape(batch_size, -1, n_heads * d_v)
        output = self.fc(context)  # [batch_size, len_q, d_model]
        return self.layer_norm(output + residual), attn
Feed Forward Net
# Position-wise feed-forward network; input and output have the same shape
class PoswiseFeedForwardNet(nn.Module):
    """Two bias-free linear layers with a ReLU in between, plus residual and LayerNorm.

    The layer sizes come from the module-level hyper-parameters ``d_model`` and ``d_ff``.
    """

    def __init__(self):
        super(PoswiseFeedForwardNet, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(d_model, d_ff, bias=False),
            nn.ReLU(),
            nn.Linear(d_ff, d_model, bias=False),
        )
        # Registered once as a sub-module so that its parameters are trained
        # and follow the model's device, instead of building a new,
        # CUDA-only LayerNorm on every forward pass.
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, inputs):
        """Apply the feed-forward network.

        Args:
            inputs: [batch_size, seq_len, d_model]

        Returns:
            Tensor of shape [batch_size, seq_len, d_model].
        """
        residual = inputs
        output = self.fc(inputs)
        return self.layer_norm(output + residual)
Encoder Layer
# One encoder block: multi-head self-attention followed by the feed-forward network
class EncoderLayer(nn.Module):
    """Single encoder layer (self-attention sub-layer + position-wise FFN sub-layer)."""

    def __init__(self):
        super().__init__()
        self.enc_self_attn = MultiHeadAttention()  # encoder self-attention
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, enc_inputs, enc_self_attn_mask):
        """Run one encoder layer.

        Args:
            enc_inputs: [batch_size, src_len, d_model]
            enc_self_attn_mask: [batch_size, src_len, src_len]

        Returns:
            enc_outputs: [batch_size, src_len, d_model], same shape as ``enc_inputs``.
            attn: [batch_size, n_heads, src_len, src_len]
        """
        x = enc_inputs
        # Self-attention: queries, keys and values all come from the same tensor.
        attended, attn = self.enc_self_attn(x, x, x, enc_self_attn_mask)
        return self.pos_ffn(attended), attn
Encoder
# The encoder has three parts: token embedding, positional encoding and
# n_layers stacked EncoderLayer blocks (attention + FFN)
class Encoder(nn.Module):
    """Transformer encoder over source token indices."""

    def __init__(self):
        super().__init__()
        self.src_emb = nn.Embedding(src_vocab_size, d_model)
        self.pos_emb = PositionalEncoding(d_model)
        # ModuleList registers every stacked layer as a sub-module.
        stacked = [EncoderLayer() for _ in range(n_layers)]
        self.layers = nn.ModuleList(stacked)

    def forward(self, enc_inputs):
        """Encode a batch of source sequences.

        Args:
            enc_inputs: token indices, [batch_size, src_len]

        Returns:
            enc_outputs: [batch_size, src_len, d_model]
            enc_self_attns: list of length n_layers with the attention weights
                of each layer, each [batch_size, n_heads, src_len, src_len].
        """
        embedded = self.src_emb(enc_inputs)  # [batch_size, src_len, d_model]
        # PositionalEncoding expects [seq_len, batch_size, d_model], hence the
        # transpose before and after; the shape is otherwise unchanged.
        hidden = self.pos_emb(embedded.transpose(0, 1)).transpose(0, 1)
        pad_mask = get_attn_pad_mask(enc_inputs, enc_inputs)  # [batch_size, src_len, src_len]
        enc_self_attns = []
        for layer in self.layers:
            hidden, layer_attn = layer(hidden, pad_mask)
            enc_self_attns.append(layer_attn)
        return hidden, enc_self_attns
Decoder Layer
# One decoder block: masked multi-head self-attention, encoder-decoder attention, then the FFN
class DecoderLayer(nn.Module):
    """Single decoder layer with self-attention, cross-attention and FFN sub-layers."""

    def __init__(self):
        super().__init__()
        self.dec_self_attn = MultiHeadAttention()  # decoder self-attention
        self.dec_enc_attn = MultiHeadAttention()   # decoder-to-encoder (cross) attention
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask):
        """Run one decoder layer.

        Args:
            dec_inputs: [batch_size, tgt_len, d_model]
            enc_outputs: [batch_size, src_len, d_model]
            dec_self_attn_mask: [batch_size, tgt_len, tgt_len]
            dec_enc_attn_mask: [batch_size, tgt_len, src_len]

        Returns:
            dec_outputs: [batch_size, tgt_len, d_model], same shape as ``dec_inputs``.
            dec_self_attn: [batch_size, n_heads, tgt_len, tgt_len]
            dec_enc_attn: [batch_size, n_heads, tgt_len, src_len]
        """
        # Self-attention: queries, keys and values all come from the decoder input.
        self_attended, dec_self_attn = self.dec_self_attn(
            dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask
        )
        # Cross-attention: queries come from the decoder, keys and values
        # from the output of the encoder stack.
        cross_attended, dec_enc_attn = self.dec_enc_attn(
            self_attended, enc_outputs, enc_outputs, dec_enc_attn_mask
        )
        return self.pos_ffn(cross_attended), dec_self_attn, dec_enc_attn
Decoder
# The decoder has three parts: token embedding, positional encoding and
# n_layers stacked DecoderLayer blocks
class Decoder(nn.Module):
    """Transformer decoder over target token indices.

    All intermediate tensors are created on the device of ``dec_inputs``, so
    the module runs wherever the model and its inputs live (CPU or GPU).
    """

    def __init__(self):
        super(Decoder, self).__init__()
        self.tgt_emb = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_emb = PositionalEncoding(d_model)
        self.layers = nn.ModuleList([DecoderLayer() for _ in range(n_layers)])

    def forward(self, dec_inputs, enc_inputs, enc_outputs):
        """Decode a batch of target sequences against the encoder output.

        Args:
            dec_inputs: target token indices, [batch_size, tgt_len]
            enc_inputs: source token indices, [batch_size, src_len]
            enc_outputs: encoder output, [batch_size, src_len, d_model]

        Returns:
            dec_outputs: [batch_size, tgt_len, d_model]
            dec_self_attns: list (length n_layers) of
                [batch_size, n_heads, tgt_len, tgt_len] tensors.
            dec_enc_attns: list (length n_layers) of
                [batch_size, n_heads, tgt_len, src_len] tensors.
        """
        dec_outputs = self.tgt_emb(dec_inputs)  # [batch_size, tgt_len, d_model]
        # PositionalEncoding expects [seq_len, batch_size, d_model].
        dec_outputs = self.pos_emb(dec_outputs.transpose(0, 1)).transpose(0, 1)

        # Padding mask for decoder self-attention (bool), [batch_size, tgt_len, tgt_len]
        dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs)
        # Look-ahead mask: a position must not see later positions.
        # The helper builds it on the CPU, so move it next to the inputs.
        dec_self_attn_subsequence_mask = get_attn_subsequence_mask(dec_inputs).to(dec_inputs.device)
        # A position is masked if it is padding OR lies in the future.
        dec_self_attn_mask = dec_self_attn_pad_mask | dec_self_attn_subsequence_mask.bool()

        # Mask for encoder-decoder attention: only source padding is hidden.
        dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs)  # [batch_size, tgt_len, src_len]

        dec_self_attns, dec_enc_attns = [], []
        for layer in self.layers:
            dec_outputs, dec_self_attn, dec_enc_attn = layer(
                dec_outputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask
            )
            dec_self_attns.append(dec_self_attn)
            dec_enc_attns.append(dec_enc_attn)
        return dec_outputs, dec_self_attns, dec_enc_attns
Transformer
# The full model: encoder + decoder + final linear projection onto the target vocabulary
class Transformer(nn.Module):
    """Encoder-decoder Transformer producing per-position target-vocabulary logits.

    The sub-modules are created without a hard-coded device; move the whole
    model with ``.to(device)`` / ``.cuda()`` after construction.
    """

    def __init__(self):
        super(Transformer, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        self.projection = nn.Linear(d_model, tgt_vocab_size, bias=False)

    def forward(self, enc_inputs, dec_inputs):
        """Run the full model.

        Args:
            enc_inputs: source token indices, [batch_size, src_len]
            dec_inputs: target token indices, [batch_size, tgt_len]

        Returns:
            logits: [batch_size * tgt_len, tgt_vocab_size], flattened so they
                can be passed straight to the cross-entropy loss.
            enc_self_attns: per-layer list of [batch_size, n_heads, src_len, src_len]
            dec_self_attns: per-layer list of [batch_size, n_heads, tgt_len, tgt_len]
            dec_enc_attns: per-layer list of [batch_size, n_heads, tgt_len, src_len]
        """
        enc_outputs, enc_self_attns = self.encoder(enc_inputs)  # enc_outputs: [batch_size, src_len, d_model]
        dec_outputs, dec_self_attns, dec_enc_attns = self.decoder(dec_inputs, enc_inputs, enc_outputs)
        dec_logits = self.projection(dec_outputs)  # [batch_size, tgt_len, tgt_vocab_size]
        return dec_logits.view(-1, dec_logits.size(-1)), enc_self_attns, dec_self_attns, dec_enc_attns
模型 损失函数 优化器
损失函数中设置了参数 ignore_index=0:因为 “pad” 这个单词的索引为 0,这样设置以后,就会忽略 “pad” 位置的损失(“pad” 本身没有意义,不需要计算)。
# Use the GPU when one is available and fall back to the CPU otherwise.
# NOTE(review): every other part of the pipeline must also avoid hard-coded
# .cuda() calls for the CPU fallback to work end to end.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Transformer().to(device)
# The softmax is part of CrossEntropyLoss; ignore_index=0 skips the padding index
criterion = nn.CrossEntropyLoss(ignore_index=0)
# Stochastic gradient descent with momentum
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.99)
训练
# Train on whatever device the model lives on instead of assuming CUDA.
train_device = next(model.parameters()).device
model.train()  # make sure dropout is in training mode, even if eval() was called earlier
for epoch in range(1000):
    for enc_inputs, dec_inputs, dec_outputs in loader:
        # enc_inputs: [batch_size, src_len]
        # dec_inputs: [batch_size, tgt_len]
        # dec_outputs: [batch_size, tgt_len]
        enc_inputs = enc_inputs.to(train_device)
        dec_inputs = dec_inputs.to(train_device)
        dec_outputs = dec_outputs.to(train_device)
        # outputs: [batch_size * tgt_len, tgt_vocab_size]
        outputs, enc_self_attns, dec_self_attns, dec_enc_attns = model(enc_inputs, dec_inputs)
        # Flatten the targets to [batch_size * tgt_len] to match the logits
        loss = criterion(outputs, dec_outputs.view(-1))
        # .item() extracts the Python float so a number, not a tensor, is formatted
        print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(loss.item()))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
输出结果:
Epoch: 0001 loss = 1.058965
Epoch: 0002 loss = 0.938208
Epoch: 0003 loss = 0.738537
Epoch: 0004 loss = 0.628805
Epoch: 0005 loss = 0.472079
Epoch: 0006 loss = 0.394795
......
测试
打断点观察预测过程:
# At inference time the target sequence is unknown, so the decoder input is
# generated token by token and fed back into the Transformer.
# The decoder starts from start_symbol; each predicted token becomes the next
# decoder input until the index of '.' is predicted.
def greedy_decoder(model, enc_input, start_symbol, max_len=None):
    """Greedily generate the decoder input for a single source sentence.

    Args:
        model: the Transformer model (must expose ``encoder``, ``decoder``
            and ``projection``).
        enc_input: encoder input of shape [1, src_len].
        start_symbol: index of the start token; 'S' (index 6) in this example.
        max_len: upper bound on the length of the generated decoder input.
            Defaults to ``tgt_len``, the maximum decoder input length declared
            for this example. Without a bound the loop would never end if the
            model failed to predict '.'.

    Returns:
        The generated decoder input, shape [1, L] with L <= max_len: the start
        symbol followed by the predicted tokens, excluding the final '.'.
    """
    if max_len is None:
        max_len = tgt_len
    end_symbol = tgt_vocab["."]
    device = enc_input.device
    with torch.no_grad():  # inference only: no gradients are needed
        # enc_input: (1, src_len) -> enc_outputs: (1, src_len, d_model)
        enc_outputs, enc_self_attns = model.encoder(enc_input)
        dec_input = torch.zeros(1, 0, dtype=enc_input.dtype, device=device)  # empty sequence
        next_symbol = start_symbol
        terminal = False
        while not terminal and dec_input.size(1) < max_len:
            # Append the previous prediction: [[6]] -> [[6, 1]] -> [[6, 1, 2]] -> ...
            step = torch.tensor([[next_symbol]], dtype=enc_input.dtype, device=device)
            dec_input = torch.cat([dec_input, step], -1)
            dec_outputs, _, _ = model.decoder(dec_input, enc_input, enc_outputs)
            projected = model.projection(dec_outputs)  # (1, cur_len, tgt_vocab_size)
            # Index of the highest-scoring word at every position
            prob = projected.squeeze(0).max(dim=-1, keepdim=False)[1]
            next_word = prob[-1]  # only the prediction for the last position is new
            next_symbol = next_word.item()
            terminal = next_symbol == end_symbol  # stop once '.' is predicted
            print(next_word)
    return dec_input
# Test
# Switch to evaluation mode so the dropout inside the positional encoding is
# disabled and predictions are deterministic.
model.eval()
infer_device = next(model.parameters()).device
enc_inputs, _, _ = next(iter(loader))  # [2, 5]
enc_inputs = enc_inputs.to(infer_device)
with torch.no_grad():  # no gradients are needed for inference
    for enc_input in enc_inputs:
        single_input = enc_input.view(1, -1)  # [1, src_len]
        # e.g. [[6, 1, 2, 3, 4]]
        greedy_dec_input = greedy_decoder(model, single_input, start_symbol=tgt_vocab["S"])
        predict, _, _, _ = model(single_input, greedy_dec_input)  # [cur_len, tgt_vocab_size]
        # Index of the best word per position; view(-1) keeps a 1-D tensor even
        # when only one token was generated.
        predict = predict.max(dim=1)[1].view(-1)
        print(enc_input, '->', [idx2word[n.item()] for n in predict])
输出结果:
tensor(1, device='cuda:0')
tensor(2, device='cuda:0')
tensor(3, device='cuda:0')
tensor(4, device='cuda:0')
tensor(8, device='cuda:0')
tensor([1, 2, 3, 4, 0], device='cuda:0') -> ['i', 'want', 'a', 'beer', '.']
tensor(1, device='cuda:0')
tensor(2, device='cuda:0')
tensor(3, device='cuda:0')
tensor(5, device='cuda:0')
tensor(8, device='cuda:0')
tensor([1, 2, 3, 5, 0], device='cuda:0') -> ['i', 'want', 'a', 'coke', '.']
这里不放全部代码了,只要将上面提及的代码(除位置编码等打印结果的代码)复制粘贴下来,就能运行。