import numpy as np
import torch
from torch import nn
from torch.nn import init
# 多头注意力
# 缩放点积计算相似性
# 方法出处 2017 NIPS《Attention Is All You Need 》
class ScaledDotProductAttention(nn.Module):
    """Multi-head scaled dot-product attention (Vaswani et al., 2017,
    "Attention Is All You Need").

    Queries, keys and values are linearly projected into ``h`` heads,
    attention scores are computed as softmax(q @ k^T / sqrt(d_k)), and the
    concatenated head outputs are projected back to ``d_model`` dimensions.
    """

    def __init__(self, d_model, d_k, d_v, h, dropout=.1):
        """
        :param d_model: model (embedding) dimension of inputs and outputs
        :param d_k: per-head dimension of queries and keys
        :param d_v: per-head dimension of values
        :param h: number of attention heads
        :param dropout: dropout probability applied to the attention weights
        """
        super(ScaledDotProductAttention, self).__init__()
        # Input projections for queries / keys / values — all heads at once.
        self.fc_q = nn.Linear(d_model, h * d_k)
        self.fc_k = nn.Linear(d_model, h * d_k)
        self.fc_v = nn.Linear(d_model, h * d_v)
        # Output projection: concatenated heads -> model dimension.
        self.fc_o = nn.Linear(h * d_v, d_model)
        self.dropout = nn.Dropout(dropout)
        # Remember dimensions and head count for forward().
        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v
        self.h = h
        self.init_weights()

    def init_weights(self):
        """Initialize parameters: Kaiming for convs, ones/zeros for batch
        norm, small normal (std=0.001) for linear layers, zero biases."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                init.kaiming_normal_(module.weight, mode='fan_out')
                if module.bias is not None:
                    init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                init.constant_(module.weight, 1)
                init.constant_(module.bias, 0)
            elif isinstance(module, nn.Linear):
                init.normal_(module.weight, std=0.001)
                if module.bias is not None:
                    init.constant_(module.bias, 0)

    def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):
        """
        :param queries: (b_s, nq, d_model) query tensor
        :param keys: (b_s, nk, d_model) key tensor
        :param values: (b_s, nk, d_model) value tensor
        :param attention_mask: optional boolean mask (True = position is
            masked out, i.e. its score is set to -inf before the softmax)
        :param attention_weights: optional multiplicative weights on the raw
            attention scores, shape (b_s, h, nq, nk)
        :return: (b_s, nq, d_model) attended output
        """
        batch, n_q = queries.shape[0], queries.shape[1]
        n_k = keys.shape[1]
        # Project, split into heads, and move the head axis forward:
        # (batch, n, h*d) -> (batch, n, h, d) -> (batch, h, n, d).
        q = self.fc_q(queries).view(batch, n_q, self.h, self.d_k).transpose(1, 2)  # (batch, h, n_q, d_k)
        # Keys get the extra transpose so that q @ k yields the score matrix.
        k = self.fc_k(keys).view(batch, n_k, self.h, self.d_k).permute(0, 2, 3, 1)  # (batch, h, d_k, n_k)
        v = self.fc_v(values).view(batch, n_k, self.h, self.d_v).transpose(1, 2)  # (batch, h, n_k, d_v)
        # Scaled dot-product scores.
        att = torch.matmul(q, k) / np.sqrt(self.d_k)  # (batch, h, n_q, n_k)
        if attention_weights is not None:
            att = att * attention_weights
        if attention_mask is not None:
            # Masked positions get -inf so they vanish after the softmax
            # (e.g. causal masking in language models).
            att = att.masked_fill(attention_mask, -np.inf)
        # Normalize scores into value weights, then apply dropout.
        att = self.dropout(torch.softmax(att, -1))
        # Weight the values, merge heads back: (batch, n_q, h*d_v).
        out = torch.matmul(att, v).transpose(1, 2).contiguous().view(batch, n_q, self.h * self.d_v)
        # Final projection to (batch, n_q, d_model).
        return self.fc_o(out)
# Smoke test for the full multi-head attention module.
if __name__ == '__main__':
    # Fake embedding batch: 50 sequences, 49 tokens each, 512-dim embeddings.
    x = torch.randn(50, 49, 512)
    attention = ScaledDotProductAttention(d_model=512, d_k=512, d_v=512, h=8)
    # Self-attention over the batch: each row of the output now mixes in
    # similarity information from the other token embeddings.
    y = attention(x, x, x)
    # Shape is preserved: (50, 49, 512).
    print(y.shape)
# --- Simplified multi-head attention (second snippet) ---
import numpy as np
import torch
from torch import nn
from torch.nn import init
# 简化的多头注意力
class SimplifiedScaledDotProductAttention(nn.Module):
    """Multi-head scaled dot-product attention without input projections.

    Unlike the full version, queries/keys/values are used as-is: they are
    only reshaped into ``h`` heads of size ``d_model // h``. A single output
    linear layer recombines the heads back to ``d_model`` dimensions.
    """

    def __init__(self, d_model, h, dropout=.1):
        """
        :param d_model: model (embedding) dimension; must be divisible by h
        :param h: number of attention heads
        :param dropout: dropout probability applied to the attention weights
        :raises ValueError: if d_model is not divisible by h
        """
        super(SimplifiedScaledDotProductAttention, self).__init__()
        # Fail fast with a clear message instead of a cryptic view() error
        # later in forward() when the head split is impossible.
        if d_model % h != 0:
            raise ValueError(
                'd_model (%d) must be divisible by h (%d)' % (d_model, h))
        self.d_model = d_model
        # Per-head query/key and value dimensions.
        self.d_k = d_model // h
        self.d_v = d_model // h
        self.h = h
        # Output projection: concatenated heads -> model dimension.
        self.fc_o = nn.Linear(h * self.d_v, d_model)
        self.dropout = nn.Dropout(dropout)
        self.init_weights()

    def init_weights(self):
        """Initialize parameters: Kaiming for convs, ones/zeros for batch
        norm, small normal (std=0.001) for linear layers, zero biases."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):
        """
        :param queries: (b_s, nq, d_model) query tensor
        :param keys: (b_s, nk, d_model) key tensor
        :param values: (b_s, nk, d_model) value tensor
        :param attention_mask: optional boolean mask (True = position is
            masked out, i.e. its score is set to -inf before the softmax)
        :param attention_weights: optional multiplicative weights on the raw
            attention scores, shape (b_s, h, nq, nk)
        :return: (b_s, nq, d_model) attended output
        """
        b_s, nq = queries.shape[:2]
        nk = keys.shape[1]
        # Split into heads without any input projection:
        # (b_s, n, d_model) -> (b_s, n, h, d) -> (b_s, h, n, d).
        q = queries.view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3)  # (b_s, h, nq, d_k)
        # Keys get the extra transpose so that q @ k yields the score matrix.
        k = keys.view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1)  # (b_s, h, d_k, nk)
        v = values.view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3)  # (b_s, h, nk, d_v)
        # Scaled dot-product scores.
        att = torch.matmul(q, k) / np.sqrt(self.d_k)  # (b_s, h, nq, nk)
        if attention_weights is not None:
            att = att * attention_weights
        if attention_mask is not None:
            # Masked positions get -inf so they vanish after the softmax.
            att = att.masked_fill(attention_mask, -np.inf)
        # Normalize scores into value weights, then apply dropout.
        att = torch.softmax(att, -1)
        att = self.dropout(att)
        # Weight the values, merge heads back, and project to d_model.
        out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v)  # (b_s, nq, h*d_v)
        out = self.fc_o(out)  # (b_s, nq, d_model)
        return out
# Smoke test for the simplified variant.
if __name__ == '__main__':
    x = torch.randn(50, 49, 512)
    attention = SimplifiedScaledDotProductAttention(d_model=512, h=8)
    # Self-attention: shape is preserved as (50, 49, 512).
    y = attention(x, x, x)
    print(y.shape)