Writing Multi-Head Self-Attention from Scratch

import torch
import torch.nn as nn
import math
class LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-12):
        super(LayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        var = (x - mean).pow(2).mean(-1, keepdim=True)
        x = (x - mean) / torch.sqrt(var + self.variance_epsilon)
        return self.weight * x + self.bias
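# A quick sanity check (a minimal sketch): with the default weight=1 and bias=0,
# this hand-rolled LayerNorm should agree with torch.nn.LayerNorm, which also
# normalizes over the last dimension with a biased variance estimate.
_x = torch.rand(2, 5, 768)
print(torch.allclose(LayerNorm(768)(_x), nn.LayerNorm(768, eps=1e-12)(_x), atol=1e-6))  # expected: True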
    
class SelfAttention(nn.Module):
    def __init__(self, num_attention_heads, input_size, hidden_size,
                 hidden_dropout_prob, attention_probs_dropout_prob):
        super(SelfAttention, self).__init__()
        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention heads (%d)"
                % (hidden_size, num_attention_heads))
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = int(hidden_size / num_attention_heads)
        self.hidden_size = hidden_size
        self.query = nn.Linear(input_size, self.hidden_size)
        self.key = nn.Linear(input_size, self.hidden_size)
        self.value = nn.Linear(input_size, self.hidden_size)
        self.attn_dropout = nn.Dropout(attention_probs_dropout_prob)
        # after self-attention: a fully-connected output layer, dropout, and LayerNorm
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.LayerNorm = LayerNorm(hidden_size, eps=1e-12)
        self.out_dropout = nn.Dropout(hidden_dropout_prob)
    def transpose_for_scores(self, x):
        # input [B, N, hidden_size] -> output [B, num_attention_heads, N, attention_head_size]
        new_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_shape)
        x = x.permute(0, 2, 1, 3)
        return x
    def forward(self, x):
        # project the input to Q/K/V and split into heads:
        # [B, N, hidden_size] -> [B, num_attention_heads, N, attention_head_size]
        query_feat = self.transpose_for_scores(self.query(x))
        key_feat = self.transpose_for_scores(self.key(x))
        value_feat = self.transpose_for_scores(self.value(x))
        # scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V
        attention_scores = torch.matmul(query_feat, key_feat.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = nn.Softmax(dim=-1)(attention_scores)
        attention_probs = self.attn_dropout(attention_probs)
        out_feat = torch.matmul(attention_probs, value_feat)
        # merge the heads back: [B, num_heads, N, head_size] -> [B, N, hidden_size]
        out_feat = out_feat.permute(0, 2, 1, 3).contiguous()
        ori_shape = out_feat.size()[:-2] + (self.hidden_size,)
        out_feat = out_feat.view(ori_shape)
        # output projection, dropout, residual connection and LayerNorm
        hidden_states = self.dense(out_feat)
        hidden_states = self.out_dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + x)
        return hidden_states
attention = SelfAttention(num_attention_heads=4, input_size=768, hidden_size=768,
                          hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1)
x = torch.rand(4, 5, 768)   # [batch_size, seq_len, input_size]
output = attention(x)
print(output.shape)         # torch.Size([4, 5, 768])
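A quick way to check the result (a minimal sketch): in eval mode both dropout layers are disabled, so the module is deterministic and repeated calls on the same input should match exactly.

attention.eval()                      # turn off attn_dropout and out_dropout
with torch.no_grad():
    out1 = attention(x)
    out2 = attention(x)
print(out1.shape)                     # torch.Size([4, 5, 768])
print(torch.allclose(out1, out2))     # True once dropout is disabled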