Self-Attention
Notes on Hung-yi Lee's Deep Learning course: the self-attention mechanism.
Understanding Q, K, and V
import torch
from torch.nn.functional import softmax
# Step 1: prepare the inputs
x = [
[1, 0, 1, 0], # Input 1
[0, 2, 0, 2], # Input 2
[1, 1, 1, 1] # Input 3
]
x = torch.tensor(x, dtype=torch.float32)
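# x holds 3 inputs, each a 4-dimensional vector: shape (3, 4)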
# Step 2: initialize the weights
w_key = [
[0, 0, 1],
[1, 1, 0],
[0, 1, 0],
[1, 1, 0]
]
w_query = [
[1, 0, 1],
[1, 0, 0],
[0, 0, 1],
[0, 1, 1]
]
w_value = [
[0, 2, 0],
[0, 3, 0],
[1, 0, 3],
[1, 1, 0]
]
w_key = torch.tensor(w_key, dtype=torch.float32)
w_query = torch.tensor(w_query, dtype=torch.float32)
w_value = torch.tensor(w_value, dtype=torch.float32)
# Step 3: derive the key, query, and value representations
keys = x @ w_key  # @ is the matrix-multiplication operator
querys = x @ w_query
values = x @ w_value
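# keys, querys, and values each have shape (3, 3): one 3-d vector per input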
print(keys)
# tensor([[0., 1., 1.],
# [4., 4., 0.],
# [2., 3., 1.]])
print(querys)
# tensor([[1., 0., 2.],
# [2., 2., 2.],
# [2., 1., 3.]])
print(values)
# tensor([[1., 2., 3.],
# [2., 8., 0.],
# [2., 6., 3.]])
# Step 4: compute the attention scores for the inputs
attn_scores = querys @ keys.T
# tensor([[ 2., 4., 4.], # attention scores from Query 1
# [ 4., 16., 12.], # attention scores from Query 2
# [ 4., 12., 10.]]) # attention scores from Query 3
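# Note: for clarity this walkthrough skips the scaling used in the Transformer
# paper, which divides the scores by sqrt(dim_k) before the softmax so the dot
# products do not grow too large; see the one-shot version at the end.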
# Step 5: apply softmax
attn_scores_softmax = softmax(attn_scores, dim=-1)  # dim=-1 applies the softmax along the last dimension, i.e. row-wise here
# tensor([[6.3379e-02, 4.6831e-01, 4.6831e-01],
# [6.0337e-06, 9.8201e-01, 1.7986e-02],
# [2.9539e-04, 8.8054e-01, 1.1917e-01]])
# For readability, approximate the above as follows
attn_scores_softmax = [
[0.0, 0.5, 0.5],
[0.0, 1.0, 0.0],
[0.0, 0.9, 0.1]
]
attn_scores_softmax = torch.tensor(attn_scores_softmax)
# Step 6: multiply the attention scores by the values
# values[:,None] has shape (3, 1, 3) and attn_scores_softmax.T[:,:,None]
# has shape (3, 3, 1), so broadcasting pairs every value with its score
weighted_values = values[:,None] * attn_scores_softmax.T[:,:,None]
# tensor([[[0.0000, 0.0000, 0.0000],
# [0.0000, 0.0000, 0.0000],
# [0.0000, 0.0000, 0.0000]],
#
# [[1.0000, 4.0000, 0.0000],
# [2.0000, 8.0000, 0.0000],
# [1.8000, 7.2000, 0.0000]],
#
# [[1.0000, 3.0000, 1.5000],
# [0.0000, 0.0000, 0.0000],
# [0.2000, 0.6000, 0.3000]]])
# Step 7: sum the weighted values to get the outputs
outputs = weighted_values.sum(dim=0)  # sum(dim=0) adds the corresponding elements across the three weighted-value layers
# tensor([[2.0000, 7.0000, 1.5000], # Output 1
# [2.0000, 8.0000, 0.0000], # Output 2
# [2.0000, 7.8000, 0.3000]]) # Output 3
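The seven steps collapse into the single expression softmax(QK^T / sqrt(d_k)) V. A minimal sketch of this one-shot version, reusing the tensors defined above and adding the sqrt(d_k) scaling from the Transformer paper (so the numbers differ slightly from the unscaled walkthrough):
import math
Q, K, V = x @ w_query, x @ w_key, x @ w_value
d_k = K.shape[-1]
# Scaled dot-product attention: softmax(Q @ K^T / sqrt(d_k)) @ V
outputs = softmax(Q @ K.T / math.sqrt(d_k), dim=-1) @ V
print(outputs)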
Self-Attention
import torch
import torch.nn as nn
import math
class SelfAttention(nn.Module):
    """
    input : batch_size * seq_len * input_dim
    Q     : batch_size * seq_len * dim_k
    K     : batch_size * seq_len * dim_k
    V     : batch_size * seq_len * dim_v
    """
    def __init__(self, input_dim, dim_k, dim_v):
        super().__init__()
        self.dim_k = dim_k
        self.q = nn.Linear(input_dim, dim_k)
        self.k = nn.Linear(input_dim, dim_k)
        self.v = nn.Linear(input_dim, dim_v)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        Q = self.q(x)  # Q: batch_size * seq_len * dim_k
        K = self.k(x)  # K: batch_size * seq_len * dim_k
        V = self.v(x)  # V: batch_size * seq_len * dim_v
        # softmax(Q @ K^T / sqrt(dim_k)) @ V, batched over the first dimension
        attention = torch.bmm(self.softmax(torch.bmm(Q, K.permute(0, 2, 1)) / math.sqrt(self.dim_k)), V)
        return attention
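A quick shape check for the class above; the sizes here are arbitrary, chosen only for illustration:
attn = SelfAttention(input_dim=64, dim_k=32, dim_v=48)
x = torch.randn(4, 10, 64)  # batch_size=4, seq_len=10, input_dim=64
print(attn(x).shape)        # torch.Size([4, 10, 48])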
Multi-Head Self-Attention
class MultiHeadSelfAttention(nn.Module):
    """
    input : batch_size * seq_len * input_dim
    Q     : batch_size * seq_len * dim_k
    K     : batch_size * seq_len * dim_k
    V     : batch_size * seq_len * dim_v
    """
    def __init__(self, input_dim, dim_k, dim_v, nums_head):
        super(MultiHeadSelfAttention, self).__init__()
        assert dim_k % nums_head == 0  # dim_k must be divisible by the number of heads
        assert dim_v % nums_head == 0  # dim_v must be divisible by the number of heads
        self.dim_k = dim_k
        self.dim_v = dim_v
        self.q = nn.Linear(input_dim, dim_k)
        self.k = nn.Linear(input_dim, dim_k)
        self.v = nn.Linear(input_dim, dim_v)
        self.nums_head = nums_head
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        # Reshape to [batch_size, n_head, seq_len, head_dim] so all heads are computed in parallel
        Q = self.q(x).view(-1, x.shape[1], self.nums_head, self.dim_k // self.nums_head).permute(0, 2, 1, 3)
        K = self.k(x).view(-1, x.shape[1], self.nums_head, self.dim_k // self.nums_head).permute(0, 2, 1, 3)
        V = self.v(x).view(-1, x.shape[1], self.nums_head, self.dim_v // self.nums_head).permute(0, 2, 1, 3)
        # Scaled dot-product attention per head; the scale is the per-head key dimension
        scores = self.softmax(torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.dim_k // self.nums_head))
        attention = torch.matmul(scores, V)  # [batch_size, n_head, seq_len, dim_v // n_head]
        attention = attention.transpose(1, 2)  # [batch_size, seq_len, n_head, dim_v // n_head]
        output = attention.reshape(-1, x.shape[1], self.dim_v)  # [batch_size, seq_len, dim_v]
        # Or, equivalently, concatenate the per-head outputs:
        # attention = attention.permute(2, 0, 1, 3)
        # output = torch.cat([head for head in attention], dim=-1)
        return output
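As before, a quick shape check with arbitrary sizes:
mhsa = MultiHeadSelfAttention(input_dim=64, dim_k=32, dim_v=48, nums_head=4)
x = torch.randn(4, 10, 64)  # batch_size=4, seq_len=10, input_dim=64
print(mhsa(x).shape)        # torch.Size([4, 10, 48])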