计算score的方式有以下几种：
import torch
class Attn(nn.module):
def __init__(self, method ,hidden_size):
super(Attn, self).__init__()
self.method = method
self.hidden_size = hidden_size
if self.method not in ['dot', 'general', 'concat']:
raise ValueError(self.method, "is not an appropriate attention method.")
self.hidden_size = hidden_size
if self.method == 'general':
self.attn = nn.Linear(self.hidden_size, hidden_size)
elif self.method == 'concat':
self.attn = nn.Linear(self.hidden_size * 2, hidden_size)
self.v = nn.Parameter(torch.FloatTensor(