文章目录
前言
目前,我研究大模型相关知识,常用到transformer结构,我想到NLP领域开篇之作Attention is all you need论文,论文实际提出transform结构,可与CNN并驾齐驱的结构,该结构利用Q/K/V模式整合全局信息,与CNN提取局部信息有所差别。介于此,我将一年前博客园更新笔记迁入该博客中,本文将介绍transform原理,也根据源码解读,深入介绍transforme经典典结构,并附有代码。
论文链接:点击这里
一、Transformer结构的原理
该部分主要介绍Attention is all you need 结构、模块、公式。暂时不介绍什么Q K V 什么Attention 什么编解码等,后面我将会根据代码解读介绍,让读者更容易理解。
1、Transform结构
Transformer由且仅由Attention和Feed Forward Neural Network(也称FFN)组成,其中Attention包含self Attention与Mutil-Head Attention,如下图:
注:模型一般可有encode与decode组成,encode负责特征编码,decode负责解码。目前,也有论文不使用解码器decode,如swin-transform。
2、位置编码公式
位置编码公式(还有很多其它公式,该论文使用此公式),如下:
3、transformer公式
4、FFN结构
FFN是由nn.Linear线性和激活函数构成,后面代码详细说明。
二、Encode模块代码解读
1、编码数据
编码输入数据介绍:
enc_input = [
[1, 3, 4, 1, 2, 3],
[1, 3, 4, 1, 2, 3],
[1, 3, 4, 1, 2, 3],
[1, 3, 4, 1, 2, 3]]
编码使用输入数据,为4x6行,表示4个句子,每个句子有6个单词,包含标点符号。
注:至于文本如何表示数字,可参考
这里
2、文本Embedding编码
文本嵌入embedding:
self.src_emb = nn.Embedding(vocab_size, d_model) # d_model=128
vocab_size:词典的大小尺寸,比如总共出现5000个词,那就输入5000。此时index为(0-4999)
d_model:嵌入向量的维度,即用多少维来表示一个词或符号
nn.Embedding()函数可使用torch调用,建议读者百度了解其功能。
随后可将输入x=enc_input,可将enc_outputs则表示嵌入成功,维度为[4,6,128]分别表示batch为4,词为6,用128维度描述词6
x = self.src_emb(x) # 词嵌入
3、位置position编码
位置编码,使用上面公式嵌入,我将不再介绍,其代码如下:
pe = torch.zeros(max_len, d_model)
position = torch.arange(0., max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0., d_model, 2) * -(math.log(10000.0) / d_model)) # 偶数列
pe[:, 0::2] = torch.sin(position * div_term) # 奇数列
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
将编码进行位置编码后,位置为[1,6,128]+输入编码的[4,6,128],相当于句子已经结合了位置编码信息,作为新新的输入,代码如下:
x = x + Variable(self.pe[:, :x.size(1)], requires_grad=False) #torch.autograd.Variable 表示有梯度的张量变量
4、Attention编码
在介绍此之前,先普及一个知识,若X与Y相等,则为self attention 否则为cross-attention,因为解码时候X!=Y.
获取Q K V 代码,实际是一个线性变化,将以上输入x变成[4,6,512],然后通过head个数8与对应dv,dk将512拆分[8,64],随后移维度位置,变成[4,8,6,64]
self.WQ = nn.Linear(d_model, d_k * n_heads) # 利用线性卷积
self.WK = nn.Linear(d_model, d_k * n_heads)
self.WV = nn.Linear(d_model, d_v * n_heads)
变化后的q k v
q_s = self.WQ(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2) # 线性卷积后再分组实现head功能
k_s = self.WK(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
v_s = self.WV(V).view(batch_size, -1, self.n_heads, self.d_v).transpose(1, 2)
attn_mask = attn_mask.unsqueeze(1).repeat(1, self.n_heads, 1, 1) # 编导对应的头
随后通过以上self公式,将其编码计算
scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(self.d_k)
attn = nn.Softmax(dim=-1)(scores)
context = torch.matmul(attn, V)
以上编码将是encode编码得到结果,我们将得到结果进行还原:
context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.d_v) # 将其还原
output = self.linear(context) # 通过线性又将其变成原来模样维度
layer_norm(output + Q) # 这里加Q 实际是对Q寻找
以上将重新得到新的输入x,维度为[4,6,128]
5、FFN编码
将以上的输出维度为[4,6,128]进行FFN层变化,实际类似线性残差网络变化,得到最终输出
class PoswiseFeedForwardNet(nn.Module):
def __init__(self, d_model, d_ff):
super(PoswiseFeedForwardNet, self).__init__()
self.l1 = nn.Linear(d_model, d_ff)
self.l2 = nn.Linear(d_ff, d_model)
self.relu = GELU()
self.layer_norm = nn.LayerNorm(d_model)
def forward(self, inputs):
residual = inputs
output = self.l1(inputs) # 一层线性卷积
output = self.relu(output)
output = self.l2(output) # 一层线性卷积
return self.layer_norm(output + residual)
重复以上顺序编码,即将得到经过FFN变化的输出x,维度为[4,6,128],将其重复步骤③-④,因其编码为6个,可重复5个便是完成相应的编码模块。
三、Decode模块代码解读
1、编码数据
解码输入数据介绍,包含以下数据输入dec_input、enc_input的输入与解码后输出的数据,维度为[4,6,128],而dec_input输入如下:
dec_input = [
[1, 0, 0, 0, 0, 0],
[1, 3, 0, 0, 0, 0],
[1, 3, 4, 0, 0, 0],
[1, 3, 4, 1, 0, 0]]
2、文本Embedding与位置编码
dec_input的Embedding与位置编码,因其与encode的实现方法一致,只需将enc_input使用dec_input取代,得到dec_outputs,因此这里将不在介绍。
3、mask编码
整体编码,代码如下:
def get_attn_pad_mask(seq_q, seq_k, pad_index):
batch_size, len_q = seq_q.size()
batch_size, len_k = seq_k.size()
pad_attn_mask = seq_k.data.eq(pad_index).unsqueeze(1)
pad_attn_mask = torch.as_tensor(pad_attn_mask, dtype=torch.int)
return pad_attn_mask.expand(batch_size, len_q, len_k)
以上代码实际是将dec_input进行处理,实际变成以下数据:
[[0, 1, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 0, 1, 1]]
将其增添维度为[4,1,6],并将其扩张为[4,6,6]
局部代码编写,实际为上三角矩阵:
[[0. 1. 1. 1. 1. 1.]
[0. 0. 1. 1. 1. 1.]
[0. 0. 0. 1. 1. 1.]
[0. 0. 0. 0. 1. 1.]
[0. 0. 0. 0. 0. 1.]
[0. 0. 0. 0. 0. 0.]]
将以上数据添加维度为[1,6,6],在将扩展变成[4,6,6]
关于整体mask与局部mask编码,我的理解是整体信息为语句4个词6个,根据解码输入编码整体信息,而局部编码是基于一个语句6*6编码信息,将其扩张重复到4个语句,
使其mask获得整体信息与局部信息。
dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs, self.pad_index) # 整体编码的mask
dec_self_attn_subsequent_mask = get_attn_subsequent_mask(dec_inputs)
dec_self_attn_mask = torch.gt((dec_self_attn_pad_mask + dec_self_attn_subsequent_mask), 0) # torch.gt(a,b) a>b 则为1否则为0
dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs, self.pad_index)
最终将mask整合,获取dec_self_attn_mask信息,同理dec_enc_attn_mask(维度为解码编码词维度)采用dec_self_attn_mask的第一步便可获取。
4、Attention编码
编码输入self-Attention,包含2部分,self Attention与cross Attention。
self attention
解码输入dec_outputs进行self.Attention:
实际使用以上Q K V公式,具体实现和编码实现方法一致,唯一不同是在Q*K^T会使用解码maskdec_self_attn_mask,其重要代码为scores.masked_fill_(attn_mask, -1e9),代码如下:
class ScaledDotProductAttention(nn.Module):
def __init__(self, d_k, device):
super(ScaledDotProductAttention, self).__init__()
self.device = device
self.d_k = d_k
def forward(self, Q, K, V, attn_mask):
scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(self.d_k)
attn_mask = torch.as_tensor(attn_mask, dtype=torch.bool)
attn_mask = attn_mask.to(self.device)
scores.masked_fill_(attn_mask, -1e9) # it is true give -1e9
attn = nn.Softmax(dim=-1)(scores)
context = torch.matmul(attn, V)
return context, attn
以上代码将执行以下代码:
context, attn = ScaledDotProductAttention(d_k=self.d_k, device=self.device)(Q=q_s, K=k_s, V=v_s,
attn_mask=attn_mask)
context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.d_v) # 将其还原
output = self.linear(context) # 通过线性又将其变成原来模样维度
dec_outputs = self.layer_norm(output + Q) # 这里加Q 实际是对Q寻找
到此为止已经完成了解码输入的self-attention模块,输出为dec_outputs实际除了增加mask编码调整Q*K^T以外,其它完全相同。
cross attention
编码输出dec_outputs进行Cross Attention:
dec_outputs, dec_enc_attn = self.dec_enc_attn(dec_outputs, enc_outputs, enc_outputs, dec_enc_attn_mask)
重点说明enc_outputs来源编码结果,是一直不变的,以上为Cross Attention 过程,以上代码除了Q来源dec_outputs,K V 来源编码输出enc_outputs以外,即论文所说X与Y不等得到的Q K V称为Cross Attention。
实际以上代码与执行解码self-Attention方法完全一致,仅仅mask更改上文提供的方法,得到输出结果为dec_outputs,因此这里将不在解释了。
5、FFN编码
该部分编码与encode的FFN一样,我将不在解释。
重复步骤上面4与5为n次,便实现解码过程。
四、源码附件(源码有注释)
最后,我给出attention is all you need的所有代码,只需简单环境便可使用,整体实现代码如下:
import json
import math
import torch
import torchvision
import torch.nn as nn
import numpy as np
from pdb import set_trace
from torch.autograd import Variable
def get_attn_pad_mask(seq_q, seq_k, pad_index):
batch_size, len_q = seq_q.size()
batch_size, len_k = seq_k.size()
pad_attn_mask = seq_k.data.eq(pad_index).unsqueeze(1)
pad_attn_mask = torch.as_tensor(pad_attn_mask, dtype=torch.int)
return pad_attn_mask.expand(batch_size, len_q, len_k)
def get_attn_subsequent_mask(seq):
attn_shape = [seq.size(0), seq.size(1), seq.size(1)]
subsequent_mask = np.triu(np.ones(attn_shape), k=1)
subsequent_mask = torch.from_numpy(subsequent_mask).int()
return subsequent_mask
class GELU(nn.Module):
def forward(self, x):
return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
class PositionalEncoding(nn.Module):
"Implement the PE function."
def __init__(self, d_model, dropout, max_len=5000): #
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
# Compute the positional encodings once in log space.
pe = torch.zeros(max_len, d_model)
position = torch.arange(0., max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0., d_model, 2) * -(math.log(10000.0) / d_model)) # 偶数列
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer('pe', pe) # 将变量pe保存到内存中,不计算梯度
def forward(self, x):
x = x + Variable(self.pe[:, :x.size(1)], requires_grad=False) # torch.autograd.Variable 表示有梯度的张量变量
return self.dropout(x)
class ScaledDotProductAttention(nn.Module):
def __init__(self, d_k, device):
super(ScaledDotProductAttention, self).__init__()
self.device = device
self.d_k = d_k
def forward(self, Q, K, V, attn_mask):
scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(self.d_k)
attn_mask = torch.as_tensor(attn_mask, dtype=torch.bool)
attn_mask = attn_mask.to(self.device)
scores.masked_fill_(attn_mask, -1e9) # it is true give -1e9
attn = nn.Softmax(dim=-1)(scores)
context = torch.matmul(attn, V)
return context, attn
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, d_k, d_v, n_heads, device):
super(MultiHeadAttention, self).__init__()
self.WQ = nn.Linear(d_model, d_k * n_heads) # 利用线性卷积
self.WK = nn.Linear(d_model, d_k * n_heads)
self.WV = nn.Linear(d_model, d_v * n_heads)
self.linear = nn.Linear(n_heads * d_v, d_model)
self.layer_norm = nn.LayerNorm(d_model)
self.device = device
self.d_model = d_model
self.d_k = d_k
self.d_v = d_v
self.n_heads = n_heads
def forward(self, Q, K, V, attn_mask):
batch_size = Q.shape[0]
q_s = self.WQ(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2) # 线性卷积后再分组实现head功能
k_s = self.WK(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
v_s = self.WV(V).view(batch_size, -1, self.n_heads, self.d_v).transpose(1, 2)
attn_mask = attn_mask.unsqueeze(1).repeat(1, self.n_heads, 1, 1) # 编导对应的头
context, attn = ScaledDotProductAttention(d_k=self.d_k, device=self.device)(Q=q_s, K=k_s, V=v_s,
attn_mask=attn_mask)
context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.d_v) # 将其还原
output = self.linear(context) # 通过线性又将其变成原来模样维度
return self.layer_norm(output + Q), attn # 这里加Q 实际是对Q寻找
class PoswiseFeedForwardNet(nn.Module):
def __init__(self, d_model, d_ff):
super(PoswiseFeedForwardNet, self).__init__()
self.l1 = nn.Linear(d_model, d_ff)
self.l2 = nn.Linear(d_ff, d_model)
self.relu = GELU()
self.layer_norm = nn.LayerNorm(d_model)
def forward(self, inputs):
residual = inputs
output = self.l1(inputs) # 一层线性卷积
output = self.relu(output)
output = self.l2(output) # 一层线性卷积
return self.layer_norm(output + residual)
class EncoderLayer(nn.Module):
def __init__(self, d_model, d_ff, d_k, d_v, n_heads, device):
super(EncoderLayer, self).__init__()
self.enc_self_attn = MultiHeadAttention(d_model=d_model, d_k=d_k, d_v=d_v, n_heads=n_heads, device=device)
self.pos_ffn = PoswiseFeedForwardNet(d_model=d_model, d_ff=d_ff)
def forward(self, enc_inputs, enc_self_attn_mask):
enc_outputs, attn = self.enc_self_attn(Q=enc_inputs, K=enc_inputs, V=enc_inputs, attn_mask=enc_self_attn_mask)
# X=Y 因此Q K V相等
enc_outputs = self.pos_ffn(enc_outputs) #
return enc_outputs, attn
class Encoder(nn.Module):
def __init__(self, vocab_size, d_model, d_ff, d_k, d_v, n_heads, n_layers, pad_index, device):
# 4 128 256 64 64 8 4 0
super(Encoder, self).__init__()
self.device = device
self.pad_index = pad_index
self.src_emb = nn.Embedding(vocab_size, d_model)
# vocab_size:词典的大小尺寸,比如总共出现5000个词,那就输入5000。此时index为(0-4999) d_model:嵌入向量的维度,即用多少维来表示一个符号
self.pos_emb = PositionalEncoding(d_model=d_model, dropout=0)
self.layers = []
for _ in range(n_layers):
encoder_layer = EncoderLayer(d_model=d_model, d_ff=d_ff, d_k=d_k, d_v=d_v, n_heads=n_heads, device=device)
self.layers.append(encoder_layer)
self.layers = nn.ModuleList(self.layers)
def forward(self, x):
enc_outputs = self.src_emb(x) # 词嵌入
enc_outputs = self.pos_emb(enc_outputs) # pos+matx
enc_self_attn_mask = get_attn_pad_mask(x, x, self.pad_index)
enc_self_attns = []
for layer in self.layers:
enc_outputs, enc_self_attn = layer(enc_outputs, enc_self_attn_mask)
enc_self_attns.append(enc_self_attn)
enc_self_attns = torch.stack(enc_self_attns)
enc_self_attns = enc_self_attns.permute([1, 0, 2, 3, 4])
return enc_outputs, enc_self_attns
class DecoderLayer(nn.Module):
def __init__(self, d_model, d_ff, d_k, d_v, n_heads, device):
super(DecoderLayer, self).__init__()
self.dec_self_attn = MultiHeadAttention(d_model=d_model, d_k=d_k, d_v=d_v, n_heads=n_heads, device=device)
self.dec_enc_attn = MultiHeadAttention(d_model=d_model, d_k=d_k, d_v=d_v, n_heads=n_heads, device=device)
self.pos_ffn = PoswiseFeedForwardNet(d_model=d_model, d_ff=d_ff)
def forward(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask):
dec_outputs, dec_self_attn = self.dec_self_attn(dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask)
dec_outputs, dec_enc_attn = self.dec_enc_attn(dec_outputs, enc_outputs, enc_outputs, dec_enc_attn_mask)
dec_outputs = self.pos_ffn(dec_outputs)
return dec_outputs, dec_self_attn, dec_enc_attn
class Decoder(nn.Module):
def __init__(self, vocab_size, d_model, d_ff, d_k, d_v, n_heads, n_layers, pad_index, device):
super(Decoder, self).__init__()
self.pad_index = pad_index
self.device = device
self.tgt_emb = nn.Embedding(vocab_size, d_model)
self.pos_emb = PositionalEncoding(d_model=d_model, dropout=0)
self.layers = []
for _ in range(n_layers):
decoder_layer = DecoderLayer(d_model=d_model, d_ff=d_ff, d_k=d_k, d_v=d_v, n_heads=n_heads, device=device)
self.layers.append(decoder_layer)
self.layers = nn.ModuleList(self.layers)
def forward(self, dec_inputs, enc_inputs, enc_outputs):
dec_outputs = self.tgt_emb(dec_inputs)
dec_outputs = self.pos_emb(dec_outputs)
dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs, self.pad_index)
dec_self_attn_subsequent_mask = get_attn_subsequent_mask(dec_inputs)
dec_self_attn_mask = torch.gt((dec_self_attn_pad_mask + dec_self_attn_subsequent_mask), 0)
dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs, self.pad_index)
dec_self_attns, dec_enc_attns = [], []
for layer in self.layers:
dec_outputs, dec_self_attn, dec_enc_attn = layer(
dec_inputs=dec_outputs,
enc_outputs=enc_outputs,
dec_self_attn_mask=dec_self_attn_mask,
dec_enc_attn_mask=dec_enc_attn_mask)
dec_self_attns.append(dec_self_attn)
dec_enc_attns.append(dec_enc_attn)
dec_self_attns = torch.stack(dec_self_attns)
dec_enc_attns = torch.stack(dec_enc_attns)
dec_self_attns = dec_self_attns.permute([1, 0, 2, 3, 4])
dec_enc_attns = dec_enc_attns.permute([1, 0, 2, 3, 4])
return dec_outputs, dec_self_attns, dec_enc_attns
class MaskedDecoderLayer(nn.Module):
def __init__(self, d_model, d_ff, d_k, d_v, n_heads, device):
super(MaskedDecoderLayer, self).__init__()
self.dec_self_attn = MultiHeadAttention(d_model=d_model, d_k=d_k, d_v=d_v, n_heads=n_heads, device=device)
self.pos_ffn = PoswiseFeedForwardNet(d_model=d_model, d_ff=d_ff)
def forward(self, dec_inputs, dec_self_attn_mask):
dec_outputs, dec_self_attn = self.dec_self_attn(dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask)
dec_outputs = self.pos_ffn(dec_outputs)
return dec_outputs, dec_self_attn
class MaskedDecoder(nn.Module):
def __init__(self, vocab_size, d_model, d_ff, d_k,
d_v, n_heads, n_layers, pad_index, device):
super(MaskedDecoder, self).__init__()
self.pad_index = pad_index
self.tgt_emb = nn.Embedding(vocab_size, d_model)
self.pos_emb = PositionalEncoding(d_model=d_model, dropout=0)
self.layers = []
for _ in range(n_layers):
decoder_layer = MaskedDecoderLayer(
d_model=d_model, d_ff=d_ff,
d_k=d_k, d_v=d_v, n_heads=n_heads,
device=device)
self.layers.append(decoder_layer)
self.layers = nn.ModuleList(self.layers)
def forward(self, dec_inputs):
dec_outputs = self.tgt_emb(dec_inputs)
dec_outputs = self.pos_emb(dec_outputs)
dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs, self.pad_index)
dec_self_attn_subsequent_mask = get_attn_subsequent_mask(dec_inputs)
dec_self_attn_mask = torch.gt((dec_self_attn_pad_mask + dec_self_attn_subsequent_mask), 0)
dec_self_attns = []
for layer in self.layers:
dec_outputs, dec_self_attn = layer(
dec_inputs=dec_outputs,
dec_self_attn_mask=dec_self_attn_mask)
dec_self_attns.append(dec_self_attn)
dec_self_attns = torch.stack(dec_self_attns)
dec_self_attns = dec_self_attns.permute([1, 0, 2, 3, 4])
return dec_outputs, dec_self_attns
class BertModel(nn.Module):
def __init__(self, vocab_size, d_model, d_ff, d_k, d_v, n_heads, n_layers, pad_index, device):
super(BertModel, self).__init__()
self.tok_embed = nn.Embedding(vocab_size, d_model)
self.pos_embed = PositionalEncoding(d_model=d_model, dropout=0)
self.seg_embed = nn.Embedding(2, d_model)
self.layers = []
for _ in range(n_layers):
encoder_layer = EncoderLayer(
d_model=d_model, d_ff=d_ff,
d_k=d_k, d_v=d_v, n_heads=n_heads,
device=device)
self.layers.append(encoder_layer)
self.layers = nn.ModuleList(self.layers)
self.pad_index = pad_index
self.fc = nn.Linear(d_model, d_model)
self.active1 = nn.Tanh()
self.classifier = nn.Linear(d_model, 2)
self.linear = nn.Linear(d_model, d_model)
self.active2 = GELU()
self.norm = nn.LayerNorm(d_model)
self.decoder = nn.Linear(d_model, vocab_size, bias=False)
self.decoder.weight = self.tok_embed.weight
self.decoder_bias = nn.Parameter(torch.zeros(vocab_size))
def forward(self, input_ids, segment_ids, masked_pos):
output = self.tok_embed(input_ids) + self.seg_embed(segment_ids)
output = self.pos_embed(output)
enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids, self.pad_index)
for layer in self.layers:
output, enc_self_attn = layer(output, enc_self_attn_mask)
h_pooled = self.active1(self.fc(output[:, 0]))
logits_clsf = self.classifier(h_pooled)
masked_pos = masked_pos[:, :, None].expand(-1, -1, output.size(-1))
h_masked = torch.gather(output, 1, masked_pos)
h_masked = self.norm(self.active2(self.linear(h_masked)))
logits_lm = self.decoder(h_masked) + self.decoder_bias
return logits_lm, logits_clsf, output
class GPTModel(nn.Module):
def __init__(self, vocab_size, d_model, d_ff,
d_k, d_v, n_heads, n_layers, pad_index,
device):
super(GPTModel, self).__init__()
self.decoder = MaskedDecoder(
vocab_size=vocab_size,
d_model=d_model, d_ff=d_ff,
d_k=d_k, d_v=d_v, n_heads=n_heads,
n_layers=n_layers, pad_index=pad_index,
device=device)
self.projection = nn.Linear(d_model, vocab_size, bias=False)
def forward(self, dec_inputs):
dec_outputs, dec_self_attns = self.decoder(dec_inputs)
dec_logits = self.projection(dec_outputs)
return dec_logits, dec_self_attns
class Classifier(nn.Module):
def __init__(self, vocab_size, d_model, d_ff,
d_k, d_v, n_heads, n_layers,
pad_index, device, num_classes):
super(Classifier, self).__init__()
self.encoder = Encoder(
vocab_size=vocab_size,
d_model=d_model, d_ff=d_ff,
d_k=d_k, d_v=d_v, n_heads=n_heads,
n_layers=n_layers, pad_index=pad_index,
device=device)
self.projection = nn.Linear(d_model, num_classes)
def forward(self, enc_inputs):
enc_outputs, enc_self_attns = self.encoder(enc_inputs)
mean_enc_outputs = torch.mean(enc_outputs, dim=1)
logits = self.projection(mean_enc_outputs)
return logits, enc_self_attns
class Translation(nn.Module):
def __init__(self, src_vocab_size, tgt_vocab_size, d_model,
d_ff, d_k, d_v, n_heads, n_layers, src_pad_index,
tgt_pad_index, device):
super(Translation, self).__init__()
self.encoder = Encoder(
vocab_size=src_vocab_size, # 5
d_model=d_model, d_ff=d_ff, # 128 256
d_k=d_k, d_v=d_v, n_heads=n_heads, # 64 64 8
n_layers=n_layers, pad_index=src_pad_index, # 4 0
device=device)
self.decoder = Decoder(
vocab_size=tgt_vocab_size, # 5
d_model=d_model, d_ff=d_ff, # 128 256
d_k=d_k, d_v=d_v, n_heads=n_heads, # 64 64 8
n_layers=n_layers, pad_index=tgt_pad_index, # 4 0
device=device)
self.projection = nn.Linear(d_model, tgt_vocab_size, bias=False)
# def forward(self, enc_inputs, dec_inputs, decode_lengths):
# enc_outputs, enc_self_attns = self.encoder(enc_inputs)
# dec_outputs, dec_self_attns, dec_enc_attns = self.decoder(dec_inputs, enc_inputs, enc_outputs)
# dec_logits = self.projection(dec_outputs)
# return dec_logits, enc_self_attns, dec_self_attns, dec_enc_attns, decode_lengths
def forward(self, enc_inputs, dec_inputs):
enc_outputs, enc_self_attns = self.encoder(enc_inputs)
dec_outputs, dec_self_attns, dec_enc_attns = self.decoder(dec_inputs, enc_inputs, enc_outputs)
dec_logits = self.projection(dec_outputs)
return dec_logits, enc_self_attns, dec_self_attns, dec_enc_attns
if __name__ == '__main__':
enc_input = [
[1, 3, 4, 1, 2, 3],
[1, 3, 4, 1, 2, 3],
[1, 3, 4, 1, 2, 3],
[1, 3, 4, 1, 2, 3]]
dec_input = [
[1, 0, 0, 0, 0, 0],
[1, 3, 0, 0, 0, 0],
[1, 3, 4, 0, 0, 0],
[1, 3, 4, 1, 0, 0]]
enc_input = torch.as_tensor(enc_input, dtype=torch.long).to(torch.device('cpu'))
dec_input = torch.as_tensor(dec_input, dtype=torch.long).to(torch.device('cpu'))
model = Translation(
src_vocab_size=5, tgt_vocab_size=5, d_model=128,
d_ff=256, d_k=64, d_v=64, n_heads=8, n_layers=4, src_pad_index=0,
tgt_pad_index=0, device=torch.device('cpu'))
logits, _, _, _ = model(enc_input, dec_input)
print(logits)
总结
本文已全部介绍完transformer结构原理及代码,但我个人有以下几点说明:
编码传递K V 解码传递Q;
self-attention 和 cross attention本质是X与Y值不同,即得到Q 和 K V 数据来源不同,但实现方法一致;
transformer重点模块为attention(一般是mutil-head attention)、FFN、位置编码、mask编码;