pytorch 多头注意力实现

记录学习过程

m_{a}= softmax\left ( \frac{qk^{T}}{\sqrt{d}} \right )v

得到多个矩阵,再将矩阵拆分,得到想要形状(这就是注意力头的参数),最后通过公式得到注意力矩阵。

class MultiHeadAttention(nn.Module):
    '''
    该层是用于计算单个注意力权重的,因此我们需要通过三层线性

    注意mask必须是张量,不然会报错
    '''
    def __init__(self,heads,d_model,dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.heads = heads
        self.d_k = d_model // heads #此为注意力放缩因子,这种设计可以根据的d_model来更好的进行缩放
    
        #layer
        self.q_l  = nn.Linear(d_model,d_model)
        self.v_l  = nn.Linear(d_model,d_model)
        self.k_l  = nn.Linear(d_model,d_model)

        self.dropout  = nn.Dropout(dropout)
        self.out      = nn.Linear(d_model,d_model)

    def attention(self, q , k, v, d_k, mask = None, dropout = None):
        '''
        input:
            q = [d_model,d_model]
            k = [d_model,d_model]
            v = [d_model,d_model]
        output:
            output = [d_model,d_model]    
        '''
        scores = torch.matmul(q,k.transpose(-2,-1)) / math.sqrt(d_k)
        if mask is not None:
            #问题
            mask = mask.unsqueeze(1)
            # 掩码创建
            scores = scores.masked_fill(mask==0,-1e9)#.to(device='cuda:0')
        scores = torch.softmax(scores,dim=-1)
        
        if dropout is not None:
            scores = dropout(scores)
        
        output = torch.matmul(scores,v)#[d_model,de_model]
        
        return output

    def forward(self,q,k,v,mask = None):
        batch_size = q.size(0)
        '''
        拆头
        [d_model,d_model] reshape为[batch_size,-1,self.heads,self.d_k]
        为什么要先重构后进行transpose操作,因为其得到的特征值要有包含原本的线性关系,直接重构会有问题
        '''
        k = self.k_l(k).view(batch_size,-1,self.heads,self.d_k)
        q = self.q_l(q).view(batch_size,-1,self.heads,self.d_k)
        v = self.v_l(v).view(batch_size,-1,self.heads,self.d_k)
        k = k.transpose(1,2)
        q = q.transpose(1,2)
        v = v.transpose(1,2)

        scores = self.attention(q,k,v,self.d_k,mask,self.dropout)#调用同类方法

        concat = scores.transpose(1,2).contiguous().view(batch_size,-1,self.d_model)
        output = self.dropout(concat)

        return output

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值