Implementation of Dot-Product Attention
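Scaled dot-product attention takes queries Q, keys K, and values V and computes

    Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V

where d_k is the key dimension. Dividing by sqrt(d_k) keeps the logits from growing with the dimension and saturating the softmax. The class below implements this, with an optional mask and dropout applied to the attention weights.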
import torch
import torch.nn as nn
import numpy as np
class dot_attention(nn.Module):
    """Scaled dot-product attention."""
    def __init__(self, attention_dropout=0.0):
        super(dot_attention, self).__init__()
        self.dropout = nn.Dropout(attention_dropout)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v, scale=None, attn_mask=None):
        """
        Forward pass.
        :param q: queries, shape [B, L_q, D]
        :param k: keys, shape [B, L_k, D]
        :param v: values, shape [B, L_v, D], with L_v == L_k
        :param scale: optional scaling factor, typically 1/sqrt(D)
        :param attn_mask: optional boolean mask; True marks positions to hide
        :return: the context tensor and the attention tensor.
        """
        attention = torch.bmm(q, k.transpose(1, 2))
        if scale is not None:
            attention = attention * scale  # apply the optional scaling factor
        if attn_mask is not None:
            # fill masked positions with -inf so softmax assigns them zero weight
            attention = attention.masked_fill(attn_mask, -np.inf)
        # normalize the scores with softmax
        attention = self.softmax(attention)
        # apply dropout to the attention weights
        attention = self.dropout(attention)
        # weight the values by the attention distribution
        context = torch.bmm(attention, v)
        return context, attention
if __name__ == '__main__':
    q = torch.ones((1, 2, 512))
    k = torch.ones((1, 17, 512))
    v = k
    attn = dot_attention()
    context, attention = attn(q, k, v)
    print("context:", context.size(), context)
    print("attention:", attention)
Implementation of Multi-Head Attention
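Multi-head attention runs the scaled dot-product attention above in several lower-dimensional subspaces in parallel. Q, K, and V are each projected into num_heads heads of size dim_per_head = model_dim / num_heads, attention is computed per head, the per-head contexts are concatenated, and a final linear layer maps the result back to model_dim:

    MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W^O
    head_i = Attention(Q W_i^Q, K W_i^K, V W_i^V)

The implementation below wraps this in dropout, a residual connection, and layer normalization.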
import torch
import torch.nn as nn
from dot_attention import dot_attention

class MultiHeadAttention(nn.Module):
    """Multi-head self-attention."""
    def __init__(self, model_dim=400, num_heads=4, dropout=0.0):
        super(MultiHeadAttention, self).__init__()
        self.dim_per_head = model_dim // num_heads  # each head works in a model_dim / num_heads subspace
        self.num_heads = num_heads
        self.linear_k = nn.Linear(model_dim, self.dim_per_head * num_heads)
        self.linear_v = nn.Linear(model_dim, self.dim_per_head * num_heads)
        self.linear_q = nn.Linear(model_dim, self.dim_per_head * num_heads)
        self.dot_product_attention = dot_attention(dropout)
        self.linear_final = nn.Linear(model_dim, model_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(model_dim)  # layer normalization for the residual branch
    def forward(self, key, value, query, attn_mask=None):
        # keep the unprojected query for the residual connection
        residual = query
        dim_per_head = self.dim_per_head
        num_heads = self.num_heads
        batch_size = key.size(0)

        # linear projections
        key = self.linear_k(key)
        value = self.linear_v(value)
        query = self.linear_q(query)

        # split into heads: [B, L, D] -> [B*H, L, D/H]
        key = key.view(batch_size, -1, num_heads, dim_per_head).transpose(1, 2).reshape(batch_size * num_heads, -1, dim_per_head)
        value = value.view(batch_size, -1, num_heads, dim_per_head).transpose(1, 2).reshape(batch_size * num_heads, -1, dim_per_head)
        query = query.view(batch_size, -1, num_heads, dim_per_head).transpose(1, 2).reshape(batch_size * num_heads, -1, dim_per_head)
        if attn_mask is not None:
            # one copy of the mask per head, keeping batch-major order
            attn_mask = attn_mask.repeat_interleave(num_heads, dim=0)

        # scaled dot-product attention
        scale = key.size(-1) ** -0.5
        context, attention = self.dot_product_attention(query, key, value, scale, attn_mask)

        # concat heads: [B*H, L, D/H] -> [B, L, D]
        context = context.view(batch_size, num_heads, -1, dim_per_head).transpose(1, 2).reshape(batch_size, -1, dim_per_head * num_heads)

        # final linear projection
        output = self.linear_final(context)
        # dropout
        output = self.dropout(output)
        # residual connection followed by layer normalization
        output = self.layer_norm(residual + output)
        return output, attention
if __name__ == '__main__':
    q = torch.ones((1, 17, 400))
    k = torch.ones((1, 17, 400))
    v = k
    multi_head_attention = MultiHeadAttention()
    output, attention = multi_head_attention(q, k, v)
    print("output:", output.size(), output)
    print("attention:", attention.size(), attention)