Xiaohei's Rich and Colorful Life: Attention Mechanisms

1.ScaledDotProductAttention

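For reference, the module below computes standard scaled dot-product attention, with a configurable temperature in place of the usual $\sqrt{d_k}$:

$\mathrm{Attention}(Q,K,V)=\mathrm{softmax}\!\left(\dfrac{QK^{\top}}{\tau}\right)V$, where $\tau$ is the temperature; the original Transformer sets $\tau=\sqrt{d_k}$.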

import torch
import torch.nn as nn
import numpy as np
class ScaledDotProductAttention(nn.Module):
    
    def __init__(self,temperature,hidden_dim,attn_dropout = 0.1):
        # hidden_dim is accepted for interface compatibility but is not used in this module
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)
        self.softmax = nn.Softmax(dim = 2)
        
    def forward(self,q,k,v,mask = None):
        # q,k:[batch_size,max_len,d_k]
        # v:[batch_size,max_len,d_v]
        # attn:[batch_size,max_len,max_len]
        attn = torch.bmm(q,k.transpose(1,2))
        attn = attn / self.temperature
        
        if mask is not None:
            # positions where mask is True are blocked: their scores become -inf before the softmax
            mask = mask.bool()
            attn = attn.masked_fill(mask,-np.inf)
        
        attn = self.softmax(attn)
        attn = self.dropout(attn)
        # output:[batch_size,max_len,d_v]
        output = torch.bmm(attn,v)
        
        return output,attn

temperature = 10    # arbitrary for this demo; the original Transformer uses sqrt(d_k)
hidden = 100        # defined but not used below (dim is passed as hidden_dim instead)
batch_size = 4
max_len = 18
dim = 300
q = torch.randn([batch_size,max_len,dim])
k = torch.randn([batch_size,max_len,dim])
v = torch.randn([batch_size,max_len,dim])
# causal mask: True above the diagonal blocks attention to future positions
# (diagonal = 1 keeps the diagonal unmasked; this mask is built for illustration and not passed below)
mask = torch.triu(torch.ones([batch_size,max_len,max_len]),diagonal = 1)
model = ScaledDotProductAttention(temperature,dim)
output,attn = model(q,k,v)
print('attn.shape:',attn.shape)
print('output.shape:',output.shape)
Output:

attn.shape: torch.Size([4, 18, 18])
output.shape: torch.Size([4, 18, 300])
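As a sanity check, the same computation without a mask can be reproduced with PyTorch's built-in torch.nn.functional.scaled_dot_product_attention, which applies the 1/sqrt(d_k) scaling internally. This is a minimal sketch, assuming PyTorch >= 2.0 and reusing batch_size, max_len and dim from above; temperature is set to sqrt(dim) to match the built-in scaling, and eval() disables dropout:

import torch.nn.functional as F

q = torch.randn([batch_size,max_len,dim])
k = torch.randn([batch_size,max_len,dim])
v = torch.randn([batch_size,max_len,dim])
# temperature = sqrt(d_k) matches the built-in 1/sqrt(d_k) scaling; eval() turns off dropout
model = ScaledDotProductAttention(dim ** 0.5,dim).eval()
output,attn = model(q,k,v)
# fused built-in attention (returns only the output tensor, no attention weights)
ref = F.scaled_dot_product_attention(q,k,v)
print(torch.allclose(output,ref,atol = 1e-5))  # should print True (up to floating-point tolerance)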

2.MultiHeadAttention

(figure: overall structure)

(figure: the Multi-Head Attention block)
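For reference, the module below implements the standard multi-head formulation, followed by dropout, a residual connection, and LayerNorm:

$\mathrm{head}_i=\mathrm{Attention}(QW_i^{Q},KW_i^{K},VW_i^{V}),\qquad \mathrm{MultiHead}(Q,K,V)=\mathrm{Concat}(\mathrm{head}_1,\dots,\mathrm{head}_h)W^{O}$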

class MultiHeadAttention(nn.Module):
    def __init__(self,n_head,d_model,d_k,d_v,dropout = 0.1):
        super().__init__()
        
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v
        
        self.w_qs = nn.Linear(d_model,n_head * d_k)
        self.w_ks = nn.Linear(d_model,n_head * d_k)
        self.w_vs = nn.Linear(d_model,n_head * d_v)
        
        nn.init.normal_(self.w_qs.weight,mean = 0,std = np.sqrt(2.0 / (d_model + d_k)))
        nn.init.normal_(self.w_ks.weight,mean = 0,std = np.sqrt(2.0 / (d_model + d_k)))
        nn.init.normal_(self.w_vs.weight,mean = 0,std = np.sqrt(2.0 / (d_model + d_v)))
        
        self.attention = ScaledDotProductAttention(temperature=np.power(d_k,0.5),hidden_dim = self.d_v)
        self.layer_norm = nn.LayerNorm(d_model)
        
        self.fc = nn.Linear(n_head * d_v,d_model)
        nn.init.xavier_normal_(self.fc.weight)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self,q,k,v,mask = None):
        # q:[batch_size,max_len,d_k]
        # k:[batch_size,max_len,d_k]
        # v:[batch_size,max_len,d_v]
        # mask:[batch_size,max_len,max_len]
        d_k,d_v,n_head = self.d_k,self.d_v,self.n_head
        
        sz_b,len_q,_ = q.size()
        sz_b,len_k,_ = k.size()
        sz_b,len_v,_ = v.size()
        
        residual = q
        
        # q:[batch_size,max_len,n_head,d_k]
        # k:[batch_size,max_len,n_head,d_k]
        # v:[batch_size,max_len,n_head,d_v]
        q = self.w_qs(q).view(sz_b,len_q,n_head,d_k)
        k = self.w_ks(k).view(sz_b,len_k,n_head,d_k)
        v = self.w_vs(v).view(sz_b,len_v,n_head,d_v)
        
        # q:[n_head,batch_size,max_len,d_k] -> [n_head*batch_size,max_len,d_k]
        # k:[n_head,batch_size,max_len,d_k] -> [n_head*batch_size,max_len,d_k]
        # v:[n_head,batch_size,max_len,d_v] -> [n_head*batch_size,max_len,d_v]
        q = q.permute(2,0,1,3).contiguous().view(-1,len_q,d_k)
        k = k.permute(2,0,1,3).contiguous().view(-1,len_k,d_k)
        v = v.permute(2,0,1,3).contiguous().view(-1,len_v,d_v)
        
        if mask is not None:
            # mask:[n_head*batch_size,max_len,max_len]
            mask = mask.repeat(n_head,1,1).bool()
        # output:[n_head*batch_size,max_len,d_v]
        # attn:[n_head*batch_size,max_len,max_len]
        output,attn = self.attention(q,k,v,mask=mask)
        
        # output:[n_head,batch_size,max_len,d_v]
        output = output.view(n_head,sz_b,len_q,d_v)
        # output:[batch_size,max_len,n_head*d_v]
        output = output.permute(1,2,0,3).contiguous().view(sz_b,len_q,-1)
        # output:[batch_size,max_len,d_model]
        output = self.dropout(self.fc(output))
        output = self.layer_norm(output + residual)
        return output,attn
n_head,d_model,d_k,d_v = 8,100,256,325
batch_size = 4
max_len = 18
model = MultiHeadAttention(n_head,d_model,d_k,d_v)
x = torch.randn([batch_size,max_len,d_model])
# causal mask: True above the diagonal blocks attention to future positions
# (diagonal = 1 keeps the diagonal unmasked; with diagonal = 0 the first row would be
#  fully masked, which yields NaNs after the softmax)
mask = torch.triu(torch.ones([batch_size,max_len,max_len]),diagonal = 1)
output,attn = model(x,x,x,mask)
print('output.shape:',output.shape)
print('attn.shape:',attn.shape)
Output:

output.shape: torch.Size([4, 18, 100])
attn.shape: torch.Size([32, 18, 18])
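For comparison, PyTorch also ships nn.MultiheadAttention. It is not a drop-in replacement for the class above (it has no residual connection or LayerNorm inside, and embed_dim must be divisible by the number of heads, so the 8 heads used above would not work with d_model = 100). A minimal usage sketch, assuming PyTorch >= 1.9 for batch_first = True and reusing x, d_model and max_len from above:

mha = nn.MultiheadAttention(embed_dim = d_model,num_heads = 4,dropout = 0.1,batch_first = True)
# 2D boolean causal mask shared across the batch; True means "do not attend"
causal_mask = torch.triu(torch.ones(max_len,max_len,dtype = torch.bool),diagonal = 1)
out,weights = mha(x,x,x,attn_mask = causal_mask)
print('out.shape:',out.shape)          # torch.Size([4, 18, 100])
print('weights.shape:',weights.shape)  # torch.Size([4, 18, 18]), averaged over heads by default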

3.GlobalGate
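Reading the forward pass below: the layer output first goes through single-head self-attention, and a sigmoid gate then interpolates between that result and the previous global matrix:

$A=\mathrm{SelfAttn}(H),\qquad g=\sigma(W_g[\,A;G_{\text{prev}}\,]),\qquad G_{\text{new}}=g\odot A+(1-g)\odot G_{\text{prev}}$

If no previous global matrix is provided, $G_{\text{new}}=A$.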

class GlobalGate(nn.Module):
    
    def __init__(self,hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.head = 1
        self.self_attention = MultiHeadAttention(self.head,self.hidden_dim,self.hidden_dim // self.head,
                                                 self.hidden_dim // self.head)
        self.G2Gupdategate = nn.Linear(2*self.hidden_dim,self.hidden_dim,bias=True)
    
    def forward(self,layer_output,global_matrix = None):
        if global_matrix is not None:
            # layer_output_selfatten:[batch_size,max_len,hidden_size]
            layer_output_selfatten,_ = self.self_attention(layer_output,layer_output,layer_output)
            # input_cat:[batch_size,max_len,2*hidden_size]
            input_cat = torch.cat([layer_output_selfatten,global_matrix],dim = 2)
            # update_gate:[batch_size,max_len,hidden_size]
            update_gate = torch.sigmoid(self.G2Gupdategate(input_cat))
            # new_global_matrix:[batch_size,max_len,hidden_size]
            new_global_matrix = update_gate * layer_output_selfatten + (1 - update_gate) * global_matrix
        else:
            # no previous global matrix: initialize it directly from self-attention
            new_global_matrix,_ = self.self_attention(layer_output,layer_output,layer_output)
        return new_global_matrix
hidden_dim = 100
batch_size = 4
max_len = 18
layer_output = torch.randn([batch_size,max_len,hidden_dim])
global_matrix = torch.randn([batch_size,max_len,hidden_dim])
model = GlobalGate(hidden_dim)
new_global_matrix = model(layer_output,global_matrix)
print('new_global_matrix.shape:',new_global_matrix.shape)
Output:

new_global_matrix.shape: torch.Size([4, 18, 100])
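Because the forward pass accepts global_matrix = None, the same module can be applied layer by layer: the global matrix is initialized from the first layer's self-attention output and then gated against each later layer's output. A minimal sketch, where layer_outputs is a hypothetical list of per-layer hidden states, reusing batch_size, max_len and hidden_dim from above:

# hypothetical per-layer hidden states, e.g. from a stack of encoder layers
layer_outputs = [torch.randn([batch_size,max_len,hidden_dim]) for _ in range(3)]
gate = GlobalGate(hidden_dim)
global_matrix = None
for layer_output in layer_outputs:
    # first iteration: global_matrix is None, so it is initialized from self-attention only
    global_matrix = gate(layer_output,global_matrix)
print('global_matrix.shape:',global_matrix.shape)  # torch.Size([4, 18, 100])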
