1.ScaledDotProductAttention
import torch
import torch.nn as nn
import numpy as np

class ScaledDotProductAttention(nn.Module):
    def __init__(self, temperature, hidden_dim, attn_dropout=0.1):
        super().__init__()
        # hidden_dim is accepted for interface compatibility but not used inside this module
        self.temperature = temperature          # scaling factor, usually sqrt(d_k)
        self.dropout = nn.Dropout(attn_dropout)
        self.softmax = nn.Softmax(dim=2)        # normalize over the key positions

    def forward(self, q, k, v, mask=None):
        # q, k: [batch_size, max_len, d_k]
        # v:    [batch_size, max_len, d_v]
        # attn: [batch_size, max_len, max_len]
        attn = torch.bmm(q, k.transpose(1, 2))
        attn = attn / self.temperature
        if mask is not None:
            # positions where mask is True are excluded from attention
            mask = mask.bool()
            attn = attn.masked_fill(mask, -np.inf)
        attn = self.softmax(attn)
        attn = self.dropout(attn)
        # output: [batch_size, max_len, d_v]
        output = torch.bmm(attn, v)
        return output, attn
temperature = 10
batch_size = 4
max_len = 18
dim = 300
q = torch.randn([batch_size, max_len, dim])
k = torch.randn([batch_size, max_len, dim])
v = torch.randn([batch_size, max_len, dim])
# causal mask: True above the diagonal, i.e. future positions are masked out
mask = torch.triu(torch.ones([batch_size, max_len, max_len]), diagonal=1)
model = ScaledDotProductAttention(temperature, dim)
output, attn = model(q, k, v, mask)
print('attn.shape:', attn.shape)
print('output.shape:', output.shape)
Output:
attn.shape: torch.Size([4, 18, 18])
output.shape: torch.Size([4, 18, 300])
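As a sanity check, with dropout disabled (eval mode) the module should reproduce a direct computation of softmax(q·kᵀ / temperature)·v. The snippet below is a minimal sketch of that comparison, reusing the tensors defined above; it is only for verification and not part of the module.

# Minimal sketch: verify the module against a direct computation (dropout off).
model.eval()
with torch.no_grad():
    out_module, attn_module = model(q, k, v)              # no mask for this check
    scores = torch.bmm(q, k.transpose(1, 2)) / temperature
    attn_ref = torch.softmax(scores, dim=2)
    out_ref = torch.bmm(attn_ref, v)
print(torch.allclose(out_module, out_ref, atol=1e-6))      # expected: True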
2.MultiHeadAttention
Overall structure and the Multi-Head Attention component:
class MultiHeadAttention(nn.Module):
    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v
        # project the d_model input into n_head sets of queries / keys / values
        self.w_qs = nn.Linear(d_model, n_head * d_k)
        self.w_ks = nn.Linear(d_model, n_head * d_k)
        self.w_vs = nn.Linear(d_model, n_head * d_v)
        nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
        nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
        nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v)))
        self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5), hidden_dim=self.d_v)
        self.layer_norm = nn.LayerNorm(d_model)
        self.fc = nn.Linear(n_head * d_v, d_model)
        nn.init.xavier_normal_(self.fc.weight)
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        # q:    [batch_size, max_len, d_model]
        # k:    [batch_size, max_len, d_model]
        # v:    [batch_size, max_len, d_model]
        # mask: [batch_size, max_len, max_len]
        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        sz_b, len_q, _ = q.size()
        sz_b, len_k, _ = k.size()
        sz_b, len_v, _ = v.size()
        residual = q
        # q: [batch_size, max_len, n_head, d_k]
        # k: [batch_size, max_len, n_head, d_k]
        # v: [batch_size, max_len, n_head, d_v]
        q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
        k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
        v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
        # q: [n_head, batch_size, max_len, d_k] -> [n_head*batch_size, max_len, d_k]
        # k: [n_head, batch_size, max_len, d_k] -> [n_head*batch_size, max_len, d_k]
        # v: [n_head, batch_size, max_len, d_v] -> [n_head*batch_size, max_len, d_v]
        q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k)
        k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k)
        v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v)
        if mask is not None:
            # mask: [n_head*batch_size, max_len, max_len]
            mask = mask.repeat(n_head, 1, 1).bool()
        # output: [n_head*batch_size, max_len, d_v]
        # attn:   [n_head*batch_size, max_len, max_len]
        output, attn = self.attention(q, k, v, mask=mask)
        # output: [n_head, batch_size, max_len, d_v]
        output = output.view(n_head, sz_b, len_q, d_v)
        # output: [batch_size, max_len, n_head*d_v]
        output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1)
        # output: [batch_size, max_len, d_model]
        output = self.dropout(self.fc(output))
        # residual connection + layer normalization
        output = self.layer_norm(output + residual)
        return output, attn
n_head, d_model, d_k, d_v = 8, 100, 256, 325
batch_size = 4
max_len = 18
model = MultiHeadAttention(n_head, d_model, d_k, d_v)
x = torch.randn([batch_size, max_len, d_model])
# causal mask: True above the diagonal, i.e. future positions are masked out
mask = torch.triu(torch.ones([batch_size, max_len, max_len]), diagonal=1)
output, attn = model(x, x, x, mask)
print('output.shape:', output.shape)
print('attn.shape:', attn.shape)
Output:
output.shape: torch.Size([4, 18, 100])
attn.shape: torch.Size([32, 18, 18])
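The bookkeeping above flattens the head dimension into the batch dimension before calling ScaledDotProductAttention. An equivalent approach keeps the head dimension and uses a single 4-D batched matmul. The sketch below illustrates that reshaping with the tensors from the example (q4, k4, v4 and the omitted dropout are illustrative, not part of the module above):

# Sketch: the same multi-head split using 4-D tensors and torch.matmul.
q4 = model.w_qs(x).view(batch_size, max_len, n_head, d_k).transpose(1, 2)   # [B, H, L, d_k]
k4 = model.w_ks(x).view(batch_size, max_len, n_head, d_k).transpose(1, 2)   # [B, H, L, d_k]
v4 = model.w_vs(x).view(batch_size, max_len, n_head, d_v).transpose(1, 2)   # [B, H, L, d_v]
scores = torch.matmul(q4, k4.transpose(-2, -1)) / np.power(d_k, 0.5)        # [B, H, L, L]
scores = scores.masked_fill(mask.bool().unsqueeze(1), float('-inf'))        # broadcast over heads
weights = torch.softmax(scores, dim=-1)
heads = torch.matmul(weights, v4)                                           # [B, H, L, d_v]
concat = heads.transpose(1, 2).contiguous().view(batch_size, max_len, n_head * d_v)
print(concat.shape)   # torch.Size([4, 18, 2600]) before the final fc projection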
3.GlobalGate
class GlobalGate(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.head = 1
        self.self_attention = MultiHeadAttention(self.head, self.hidden_dim, self.hidden_dim // self.head,
                                                 self.hidden_dim // self.head)
        self.G2Gupdategate = nn.Linear(2 * self.hidden_dim, self.hidden_dim, bias=True)

    def forward(self, layer_output, global_matrix=None):
        if global_matrix is not None:
            # layer_output_selfatten: [batch_size, max_len, hidden_size]
            layer_output_selfatten, _ = self.self_attention(layer_output, layer_output, layer_output)
            # input_cat: [batch_size, max_len, 2*hidden_size]
            input_cat = torch.cat([layer_output_selfatten, global_matrix], dim=2)
            # update_gate: [batch_size, max_len, hidden_size]
            update_gate = torch.sigmoid(self.G2Gupdategate(input_cat))
            # new_global_matrix: [batch_size, max_len, hidden_size]
            new_global_matrix = update_gate * layer_output_selfatten + (1 - update_gate) * global_matrix
        else:
            # first call: no previous global matrix, so just self-attend over the layer output
            new_global_matrix, _ = self.self_attention(layer_output, layer_output, layer_output)
        return new_global_matrix
hidden_dim = 100
batch_size = 4
max_len = 18
layer_output = torch.randn([batch_size, max_len, hidden_dim])
global_matrix = torch.randn([batch_size, max_len, hidden_dim])
model = GlobalGate(hidden_dim)
new_global_matrix = model(layer_output, global_matrix)
print('new_global_matrix.shape:', new_global_matrix.shape)
Output:
new_global_matrix.shape: torch.Size([4, 18, 100])
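A gate like this is typically applied once per encoder layer, carrying a running global representation across layers: the first call initializes the global matrix, and each later call blends the layer's self-attended output with the previous global matrix via the update gate. The snippet below is a hedged sketch of that assumed usage pattern; the stacked layer_outputs list is purely illustrative.

# Sketch: updating the global matrix layer by layer (assumed usage pattern).
# layer_outputs stands in for the hidden states of successive encoder layers.
layer_outputs = [torch.randn([batch_size, max_len, hidden_dim]) for _ in range(3)]
global_matrix = None
for layer_output in layer_outputs:
    global_matrix = model(layer_output, global_matrix)   # GlobalGate instance from above
print(global_matrix.shape)   # torch.Size([4, 18, 100])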