PyTorch Implementation of "Attention Is All You Need"
Self Attention
Diagram
![(left) Scaled Dot-Product Attention. (right) Multi-Head Attention](https://img-blog.csdnimg.cn/c9f80dcbedd64664852654791b63164e.png?x-oss-process=image/watermark,type_ZHJvaWRzYW5zZmFsbGJhY2s,shadow_50,text_Q1NETiBAWWVsbG93Q29vbHM=,size_20,color_FFFFFF,t_70,g_se,x_16#pic_center)
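For reference, the two panels in the figure correspond to the following formulas from the paper:

$$\mathrm{Attention}(Q,K,V)=\mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V$$

$$\mathrm{MultiHead}(Q,K,V)=\mathrm{Concat}(\mathrm{head}_1,\ldots,\mathrm{head}_h)W^{O},\quad \mathrm{head}_i=\mathrm{Attention}(QW_i^{Q},\,KW_i^{K},\,VW_i^{V})$$

Note that the implementation below divides by $\sqrt{\text{embed\_size}}$ rather than $\sqrt{d_k}$, following a common open-source variant of this code.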
Implementation
```python
import torch
import torch.nn as nn


class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        """
        :param embed_size: int, dimensionality of the token embeddings
        :param heads: int, number of attention heads
        """
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads
        assert (self.head_dim * heads == embed_size), "Embed size needs to be divisible by heads."

        # Per-head linear projections for values, keys and queries
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(self.head_dim * heads, embed_size)

    def forward(self, values, keys, queries, mask):
        """
        :param values: (N, value_len, embed_size)
        :param keys: (N, key_len, embed_size)
        :param queries: (N, query_len, embed_size)
        :param mask: (N, heads, query_len, key_len) or broadcastable; 0 marks masked positions
        :return out: (N, query_len, embed_size)
        """
        N = queries.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], queries.shape[1]

        # Split the embedding into self.heads pieces of size head_dim
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = queries.reshape(N, query_len, self.heads, self.head_dim)

        # Apply the per-head projections
        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # energy[n, h, q, k] = dot product of query q and key k for head h of sample n
        energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        # Scale, then softmax over the key dimension
        attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)

        # Weighted sum of values, then concatenate the heads
        out = torch.einsum("nhql,nlhd->nqhd", attention, values).reshape(
            N, query_len, self.heads * self.head_dim
        )
        out = self.fc_out(out)
        return out
```
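A minimal smoke test of the module above (the sizes here are arbitrary, chosen only for illustration):

```python
import torch

embed_size, heads = 256, 8
attention = SelfAttention(embed_size, heads)

x = torch.randn(2, 10, embed_size)   # (N, seq_len, embed_size)
out = attention(x, x, x, mask=None)  # self-attention: values = keys = queries = x
print(out.shape)                     # torch.Size([2, 10, 256])
```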
Notes
`energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)` computes the dot product between every query and every key for each head: the shared subscript `d` (head_dim) is summed out, so `energy` has shape (N, heads, query_len, key_len).
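As a sanity check (not part of the original post), the same tensor can be produced with an explicit permute-and-matmul, which may make the index pattern easier to read:

```python
import torch

# Hypothetical shapes: N=2 samples, 8 heads, query_len=5, key_len=7, head_dim=32
queries = torch.randn(2, 5, 8, 32)   # (N, query_len, heads, head_dim)
keys = torch.randn(2, 7, 8, 32)      # (N, key_len, heads, head_dim)

energy_einsum = torch.einsum("nqhd,nkhd->nhqk", queries, keys)

# Equivalent: move heads in front, then do a batched matmul Q @ K^T per head
q = queries.permute(0, 2, 1, 3)      # (N, heads, query_len, head_dim)
k = keys.permute(0, 2, 3, 1)         # (N, heads, head_dim, key_len)
energy_matmul = q @ k                # (N, heads, query_len, key_len)

print(torch.allclose(energy_einsum, energy_matmul))  # True
```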