1. Start simple: implementing self-attention
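Self-attention projects the same input X into queries, keys, and values with three linear layers and computes softmax(QK^T / sqrt(d)) V; the sqrt(d) scaling keeps the dot products from pushing the softmax into its saturated region.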
import torch
from torch import nn
from math import sqrt

# cf. the built-in torch.nn.TransformerEncoder
class Self_Attention(nn.Module):
    def __init__(self, model_dim):  # X: batch_size * seq_len * input_dim
        '''
        :param model_dim: dimension of the hidden variables inside the transformer
        '''
        super(Self_Attention, self).__init__()  # Wq / Wk / Wv: input_dim * dim_k
        self.model_dim = model_dim
        self.q = nn.Linear(model_dim, model_dim)
        self.k = nn.Linear(model_dim, model_dim)
        self.v = nn.Linear(model_dim, model_dim)
        self.softmax = nn.Softmax(dim=-1)
        # self.norm = 1 / sqrt(dim_k)

    def forward(self, X):
        Q, K, V = self.q(X), self.k(X), self.v(X)  # Q, K, V: batch_size * seq_len * dim_k
        Attn = self.softmax(torch.bmm(Q, K.transpose(2, 1)) / sqrt(self.model_dim))
        O = torch.bmm(Attn, V)  # O: batch_size * seq_len * dim_k
        return O
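A minimal smoke test for the module above; the batch size, sequence length, and model_dim below are illustrative assumptions, not values from the original snippet.

if __name__ == '__main__':
    X = torch.rand((2, 10, 512))   # batch_size=2, seq_len=10, model_dim=512 (assumed)
    self_attn = Self_Attention(512)
    O = self_attn(X)
    print(O.shape)                 # torch.Size([2, 10, 512])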
2. multihead attention
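Multi-head attention runs num_heads scaled dot-product attentions in parallel, each over a model_dim // num_heads slice of the representation, and merges the per-head outputs before a final linear projection.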
import torch
from torch import nn
from math import sqrt
import numpy as np
import copy
class dot_attention(nn.Module):
    """
    Scaled dot-product attention.
    """
    def __init__(self, attention_dropout=0.0):  # dropout inside the block; usually left at 0
        super(dot_attention, self).__init__()
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, Q, K, V, attn_mask=None):
        scale = Q.shape[-1]  # d_k, the per-query feature dimension
        Attn = torch.bmm(Q, K.transpose(2, 1)) / sqrt(scale)
        if attn_mask is not None:
            Attn = Attn.masked_fill(attn_mask, -np.inf)
        Attn = self.softmax(Attn)
        Attn = self.dropout(Attn)
        O = torch.bmm(Attn, V)
        return Attn, O
class MultiHeadAttention(nn.Module):
    """
    Multi-head attention built on top of dot_attention.
    """
    def __init__(self, model_dim, num_heads, dropout=0.0):
        super(MultiHeadAttention, self).__init__()
        self.dim_per_head = model_dim // num_heads
        self.num_heads = num_heads
        self.model_dim = model_dim
        self.k = nn.Linear(model_dim, self.dim_per_head * num_heads)
        self.q = nn.Linear(model_dim, self.dim_per_head * num_heads)
        self.v = nn.Linear(model_dim, self.dim_per_head * num_heads)
        self.dot_attn = dot_attention(dropout)
        self.linear_final = nn.Linear(model_dim, model_dim)
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(model_dim)

    def forward(self, X, attn_mask=None):  # X: batch_size * seq_len * dim
        residual = X
        batch_size = X.shape[0]
        if attn_mask is not None:
            attn_mask = attn_mask.repeat(self.num_heads, 1, 1)
        K, Q, V = self.k(X), self.q(X), self.v(X)
        # flat reshape into batch_size * num_heads chunks; a strictly correct head split would
        # view to (batch, seq, heads, dim_per_head) and transpose, as in the version in section 4
        Q = Q.view(batch_size * self.num_heads, -1, self.dim_per_head)
        K = K.view(batch_size * self.num_heads, -1, self.dim_per_head)
        V = V.view(batch_size * self.num_heads, -1, self.dim_per_head)
        Attn, O = self.dot_attn(Q, K, V, attn_mask)
        O = O.view(batch_size, -1, self.model_dim)  # merge heads back
        O = self.linear_final(O)
        O = self.dropout(O)
        O = self.norm(O + residual)                 # residual connection + LayerNorm
        return O
if __name__ == '__main__':
    Q = torch.rand((1, 5, 512))
    K = torch.rand((1, 10, 512))
    V = copy.deepcopy(K)
    cal_attn = dot_attention()
    Attn, O = cal_attn(Q, K, V)
    print("Attn:", Attn)
    print("O:", O)
3. A simplified version without attn_mask or dropout
class MultiHeadAttention(nn.Module):
    def __init__(self, model_dim, head_nums):
        super(MultiHeadAttention, self).__init__()
        assert model_dim % head_nums == 0
        self.model_dim = model_dim
        self.head_nums = head_nums
        self.dim_per_head = model_dim // head_nums
        self.linear_k = nn.Linear(model_dim, model_dim)
        self.linear_q = nn.Linear(model_dim, model_dim)
        self.linear_v = nn.Linear(model_dim, model_dim)
        self.softmax = nn.Softmax(dim=-1)
        self.fc = nn.Linear(model_dim, model_dim)

    def forward(self, X):  # X: batch_size * seq_len * dim
        bsz = X.shape[0]
        K, Q, V = self.linear_k(X), self.linear_q(X), self.linear_v(X)
        K = K.view(bsz * self.head_nums, -1, self.dim_per_head)
        Q = Q.view(bsz * self.head_nums, -1, self.dim_per_head)
        V = V.view(bsz * self.head_nums, -1, self.dim_per_head)
        Attn = self.softmax(torch.bmm(Q, K.transpose(2, 1)) / sqrt(self.dim_per_head))
        O = torch.bmm(Attn, V)
        O = O.view(bsz, -1, self.model_dim)
        O = self.fc(O)
        return O
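A quick shape check for the simplified version as well; it reuses the torch / nn / sqrt imports from the previous snippet, and the sizes are assumptions for illustration.

if __name__ == '__main__':
    mha = MultiHeadAttention(model_dim=512, head_nums=8)
    X = torch.rand((2, 10, 512))
    print(mha(X).shape)  # torch.Size([2, 10, 512])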
4. For reference, the multi-head attention code produced by 文心一言 (ERNIE Bot)
import torch
import torch.nn as nn

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadSelfAttention, self).__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.query_linear = nn.Linear(d_model, d_model)
        self.key_linear = nn.Linear(d_model, d_model)
        self.value_linear = nn.Linear(d_model, d_model)
        self.fc_out = nn.Linear(d_model, d_model)
        self.attention_weights = None

    def split_heads(self, x, batch_size):
        # batch_size * seq_len * d_model -> batch_size * num_heads * seq_len * head_dim
        x = x.view(batch_size, -1, self.num_heads, self.head_dim)
        return x.transpose(2, 1)

    def combine_heads(self, x):
        # batch_size * num_heads * seq_len * head_dim -> batch_size * seq_len * d_model
        x = x.transpose(2, 1).contiguous()
        return x.view(x.size(0), -1, self.num_heads * self.head_dim)

    def forward(self, inputs):
        batch_size = inputs.size(0)
        q = self.query_linear(inputs)
        k = self.key_linear(inputs)
        v = self.value_linear(inputs)
        q = self.split_heads(q, batch_size)
        k = self.split_heads(k, batch_size)
        v = self.split_heads(v, batch_size)
        scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))
        self.attention_weights = nn.Softmax(dim=-1)(scores)
        attended = torch.matmul(self.attention_weights, v)
        attended = self.combine_heads(attended)
        output = self.fc_out(attended)
        return output
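Finally, a minimal sanity check for this last version; the sizes are assumptions for illustration and not part of the generated code.

if __name__ == '__main__':
    model = MultiHeadSelfAttention(d_model=512, num_heads=8)
    x = torch.rand((2, 10, 512))
    out = model(x)
    print(out.shape)                      # torch.Size([2, 10, 512])
    print(model.attention_weights.shape)  # torch.Size([2, 8, 10, 10])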