transformer是一种编解码(encoder-decoer)结构,用于自然语言处理、计算机视觉等领域,编解码结构是当前大模型必包含的部分。
文章目录
编解码结构图:
transformer模块编码输入得到特征,然后解码得到输出。
transformer论文的一张经典图:
结合transformer论文和代码,模块主要包括了:
-
词嵌入模块(input embedding)
-
位置编码模块(Positional Encoding)
-
多头注意力机制模块(Multi-Head Attention)
-
层归一化模块(LayNorm)
-
残差模块
-
前馈神经网络模块(FFN)
-
交叉多头注意力机制模块(Cross Multi-Head Attention)
-
掩膜多头注意力机制模块(Masked Multi-Head Attention)
1. 词嵌入模块
词嵌入模块调用nn.Embedding,其主要作用是将每个单词表示成一个向量,方便下一步计算和处理。
class TokenEmbedding(nn.Embedding):
"""
Token Embedding using torch.nn
they will dense representation of word using weighted matrix
"""
def __init__(self, vocab_size, d_model):
"""
class for token embedding that included positional information
:param vocab_size: 字典中词的个数
:param d_model: 嵌入维度
"""
super(TokenEmbedding, self).__init__(vocab_size, d_model, padding_idx=1)
若batch_size是bs,输入词的个数是seq_len,词嵌入维度是n_model。
因此输入的维度x:tensor(bs,seq_len)
TokenEmbdding(x)后的维度:tensor(bs,seq_len,d_model)
2.位置编码模块
位置编码模块根据每个词的位置进行编码得到位置向量,原理是采用三角函数编码。如下图,左边采用正弦函数编码,右边采用余弦函数编码,最后进行拼接。
其中行表示嵌入维度,列表示词的位置。
另一种位置编码方法是正弦函数和余弦函数根据位置进行交叉,如位置1采用sin,位置2采用cos,位置3采用sin,位置4采用cos,以此类推。如下图:
若batch_size是bs,输入词的个数是seq_len,词嵌入维度是n_model。
class PositionalEncoding(nn.Module):
"""
compute sinusoid encoding.
"""
def __init__(self, d_model, max_len, device):
"""
constructor of sinusoid encoding class
:param d_model: dimension of model,词嵌入维度
:param max_len: max sequence length,最大语句的长度
:param device: hardware device setting
"""
super(PositionalEncoding, self).__init__()
# same size with input matrix (for adding with input matrix)
self.encoding = torch.zeros(max_len, d_model, device=device)
self.encoding.requires_grad = False # we don't need to compute gradient
pos = torch.arange(0, max_len, device=device)
pos = pos.float().unsqueeze(dim=1)
# 1D => 2D unsqueeze to represent word's position
_2i = torch.arange(0, d_model, step=2, device=device).float()
# 'i' means index of d_model (e.g. embedding size = 50, 'i' = [0,50])
# "step=2" means 'i' multiplied with two (same with 2 * i)
# 位置编码
self.encoding[:, 0::2] = torch.sin(pos / (10000 ** (_2i / d_model)))
self.encoding[:, 1::2] = torch.cos(pos / (10000 ** (_2i / d_model)))
# compute positional encoding to consider positional information of words
因此输入的维度x:tensor(bs,seq_len)
PositionalEncoding(x),输出的维度是:tensor(seq_len,d_model)
将词嵌入模块与位置模块相加得到Encoder模块的输入。
维度表示这一过程:
tensor(bs,seq_len,n_model) + tensor(seq_len,d_model) = tensor(bs,seg_len,n_model)
3. 多头注意力机制模块
多头注意力机制模块是transformer的核心,是encoder和decoder的重要组成部分。
多头注意力机制模块通过多个自注意力机制拼接,然后通过全连接层得到输出,理解自注意力机制就能很好的理解多头注意力机制。
3.1 自注意力机制模块
自注意力机制模块通过WQ,WK,WV矩阵得到Q,K,V矩阵。如下图:
然后矩阵Q与矩阵K的转置相乘得到每个词间的相关系数,并通过词嵌入维度和softmax对相关系数进行归一化。
最后乘V矩阵,得到自注意力机制的输出,
向量维度表示这一过程:
输入x:tensor(bs,seq_len,n_model)
WQ,WK,WV:tensor(n_model,n_model)
x@WQ、x@WK和x@WV得到Q,K,V的矩阵:tensor(bs,seq_len,n_model),其中@表示矩阵内即
Q@K^T:tensor(bs,seq_len,seq_len)
coef=softmax(Q@K^T)/sqrt(n_model):tensor(bs,seq_len,seq_len)
coef@V:tensor(bs,seq_len,n_model)
3.2 多头注意力机制
多头注意力机制,顾名思义,包含多个自注意力机制,然后将多个自注意力机制的输出进行拼接,最后通过全连接层得到输出。
如下图,输入x通过多个自注意力机制得到多个Q,K,V。
然后按照上节描述的那样,得到多个Z。
对多个注意力机制的输出Z进行拼接:
最后喂入全连接层,得到最终的结果。
向量维度表示这一过程:
多头:n_head
输入x:tensor(bs,seq_len,n_model)
n_head的自注意力机制模块的输出:tensor(bs,seq_len,n_model)
拼接:(bs,seq_len,n_model*n_head)
全连接层权重:tensor(n_model*n_head,n_model)
喂入全连接层后的输出:tensor(bs,seq_len,n_model)
我们通过矩阵思想实现多头注意力机制模块,矩阵每行表示某词的嵌入向量,行数表示词个数,列数表示嵌入维度。
3.3 为什么使用自注意力机制模块
这个机制使得模型能够自动确定输入序列中的不同元素对生成特定输出的重要性,这种权重分配方法允许模型更好地理解序列中的上下文关系,同时处理顺序无关的数据。
4. 层归一化模块
层归一化模块用于调整数据范围,加速模型的训练,LayerNorm在每个样本上统计所有维度的值,计算均值和方差,这样的做的优点是可以不受样本数的限制。
向量维度表示这一过程:
x: tensor(bs,seq_len,n_model)
LayerNorm(x): tensor(bs,seq_len,n_model)
5. 残差模块
这个比较简单,相信只要了解深度学习都知道resnet网络。
向量维度表示这一过程:
x: tensor(bs,seq_len,n_model)
selfAttentions(x): tensor(bs,seq_len,n_model)
y: tensor(bs,seq_len,n_model)
层归一化LayerNorm(y): tensor(bs,seq_len,n_model)
6. 前馈神经网络模块
FFN模块通过特征变换和维度扩展,使得模型能够更好地理解和表示输入序列的信息,从而提高了自然语言处理任务的性能,模块通过两个全连接层实现。
self.linear1 = nn.Linear(d_model, hidden)
self.linear2 = nn.Linear(hidden, d_model)
向量维度表示这一过程:
输入x: tensor(bs,seq_len,n_model)
x=linear1(x): tensor(bs,seq_len,hidden)
linear2(x): tensor(bs,seq_len,n_model)
7. 交叉多头注意力机制模块
交叉多头注意力模块位置如下标注的红色矩形框:
除了Q,K,V的含义不一样,交叉多头注意力模块和多头注意力机制非常相似,交叉注意力机制计算两个序列的注意力,用于处理两个序列之间的语义关系,论文中是计算输入序列和输出序列的注意力,其中K和V来源于输入序列,Q来源于输出序列。多头注意力机制是计算单个序列的注意力,Q,K,V都来源于同一个序列。
向量维度表示这一过程:
q: tensor(bs,encoder_seq_len,d_model)
v: tensor(bs,encoder_seq_len,d_model)
k: tensor(bs,decoder_seq_len,d_model)
其中bs,encoder_seq_len,decoder_seq_len,d_model分别为样本个数,输入语句的长度,输出语句的长度,词嵌入维度
x = multi_head_attention(q,k,v) # 多头注意力模块
x的维度:tensor(bs,decoder_seq_len,n_model)
8. 掩码多头注意力机制模块
掩码多头注意力机制模块位置如下红色矩形框:
掩码的目的是为了防止网络看到不该看的内容,输出语句是有前后关系的,Transformer推理时,我们是一个词一个词的输出,但在训练时这样做效率太低了。我们还是希望利用向量的思想囊括所有的输出单词,因此我们会将target一次性给到Transformer,利用掩码选择计算注意力机制的单词序列。
参考链接:http://jalammar.github.io/illustrated-transformer/