This code was written based on my own understanding of the Transformer. To use it, simply instantiate the Positional_Encoding class and the Encoder class; when using a multi-layer Encoder, you can set how many times the Encoder block is repeated.
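A minimal usage sketch of that idea is shown below; the constructor arguments and the way the layer count is handled are illustrative assumptions only, and should be matched to the actual class definitions in the code that follows:

    pe = Positional_Encoding(embed_dim, max_seq_len, dropout, device)   # hypothetical signature
    encoder = Encoder(model_dim, n_heads, hidden_dim, dropout)          # hypothetical signature
    x = pe(token_embeddings)
    for _ in range(num_encoder_layers):   # repeat the Encoder block the desired number of times
        x = encoder(x)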
# coding=utf-8
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
# Scaled dot-product attention
class Scaled_Dot_Product_Attention(nn.Module):
    '''Scaled Dot-Product Attention'''
    def __init__(self):
        super(Scaled_Dot_Product_Attention, self).__init__()

    def forward(self, Q, K, V, scale=None):
        '''
        Args:
            Q: [batch_size, len_Q, dim_Q]
            K: [batch_size, len_K, dim_K]
            V: [batch_size, len_V, dim_V]
            scale: scaling factor; following the paper, pass 1 / sqrt(dim_K)
            Q, K, V have the same dimensions
        Return:
            the context tensor after self-attention
        '''
        # attention scores: [batch_size, len_Q, len_K]
        # attention = torch.matmul(Q, K.transpose(-1, -2))
        attention = torch.matmul(Q, K.permute(0, 2, 1))
        if scale:
            attention = attention * scale
        # if mask:  # TODO: apply padding/causal mask here
        #     attention = attention.masked_fill_(mask == 0, -1e9)
        attention = F.softmax(attention, dim=-1)
        # weighted sum of the values: [batch_size, len_Q, dim_V]
        context = torch.matmul(attention, V)
        return context
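# Example usage (a minimal sketch; the batch size, sequence length and
# dimension below are illustrative assumptions, not part of the original code):
#   q = torch.rand(2, 5, 64)                            # [batch_size, len_Q, dim_Q]
#   k = torch.rand(2, 5, 64)                            # [batch_size, len_K, dim_K]
#   v = torch.rand(2, 5, 64)                            # [batch_size, len_V, dim_V]
#   attn = Scaled_Dot_Product_Attention()
#   context = attn(q, k, v, scale=k.size(-1) ** -0.5)   # context: [2, 5, 64]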
# Multi-head attention
class MultiHeadAttention(nn.Module):
    def __init__(self, model_dim, n_heads, dropout=0.0):
        """
        :param model_dim: the model dimension
        :param n_heads: number of attention heads
        :param dropout: dropout rate
        """
        super(MultiHeadAttention, self).__init__()