# Dot-product attention implementation (点积注意力的实现方法)
import torch
import torch.nn as nn
import numpy as np
class dot_attention(nn.Module):
""" 点积注意力机制"""
def __init__(self, attention_dropout=0.0):
super(dot_attention, self).__init__()
self.dropout = nn.Dropout(attention_dropout)
self.softmax = nn.Softmax(dim=2)
def forward(self, q, k, v, scale=None, attn_mask=None):
"""
前向传播
:param q:
:param k:
:param v:
:param scale:
:param attn_mask:
:return: 上下文张量和attention张量。
"""
attention = torch.bmm(q, k.transpose(1, 2))
if scale:
attention = attention * scale # 是否设置缩放
if attn_mask:</