import torch
import torch.nn as nn
import math
class LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-12):
        super(LayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        # Normalize over the last (feature) dimension, then apply the learned affine transform.
        mean = x.mean(-1, keepdim=True)
        var = (x - mean).pow(2).mean(-1, keepdim=True)
        x = (x - mean) / torch.sqrt(var + self.variance_epsilon)
        return self.weight * x + self.bias
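As a quick sanity check (not part of the original post), this hand-written LayerNorm should agree with PyTorch's built-in nn.LayerNorm when both use the same eps and are freshly initialized (weight all ones, bias all zeros). The snippet below is a minimal sketch of that comparison.

# Minimal sketch: compare the custom LayerNorm against torch.nn.LayerNorm.
# Assumes freshly initialized modules, so both have weight=1 and bias=0.
x_check = torch.rand(2, 5, 768)
ln_custom = LayerNorm(768, eps=1e-12)
ln_torch = nn.LayerNorm(768, eps=1e-12)
print(torch.allclose(ln_custom(x_check), ln_torch(x_check), atol=1e-5))  # expected: True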
class SelfAttention(nn.Module):
    def __init__(self, num_attention_heads, input_size, hidden_size,
                 hidden_dropout_prob, attention_probs_dropout_prob):
        super(SelfAttention, self).__init__()
        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (hidden_size, num_attention_heads))
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = int(hidden_size / num_attention_heads)
        self.hidden_size = hidden_size

        self.query = nn.Linear(input_size, self.hidden_size)
        self.key = nn.Linear(input_size, self.hidden_size)
        self.value = nn.Linear(input_size, self.hidden_size)
        self.attn_dropout = nn.Dropout(attention_probs_dropout_prob)

        # After self-attention: a dense projection, dropout, residual add, and LayerNorm on the output.
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.LayerNorm = LayerNorm(hidden_size, eps=1e-12)
        self.out_dropout = nn.Dropout(hidden_dropout_prob)
    def transpose_for_scores(self, x):
        # Reshape [B, N, hidden_size] -> [B, num_attention_heads, N, attention_head_size]
        new_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_shape)
        x = x.permute(0, 2, 1, 3)
        return x
    def forward(self, x):
        # Project the input into per-head query, key, and value tensors.
        query_feat = self.transpose_for_scores(self.query(x))
        key_feat = self.transpose_for_scores(self.key(x))
        value_feat = self.transpose_for_scores(self.value(x))

        # Scaled dot-product attention: softmax(QK^T / sqrt(d_head)) V.
        attention_scores = torch.matmul(query_feat, key_feat.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = nn.Softmax(dim=-1)(attention_scores)
        attention_probs = self.attn_dropout(attention_probs)
        out_feat = torch.matmul(attention_probs, value_feat)

        # Merge the heads back: [B, heads, N, head_size] -> [B, N, hidden_size].
        out_feat = out_feat.permute(0, 2, 1, 3).contiguous()
        ori_shape = out_feat.size()[:-2] + (self.hidden_size,)
        out_feat = out_feat.view(ori_shape)

        # Output projection, dropout, residual connection, and LayerNorm.
        hidden_states = self.dense(out_feat)
        hidden_states = self.out_dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + x)
        return hidden_states
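The forward pass above computes softmax(QK^T / sqrt(d_head)) V for each head. For reference only (this assumes PyTorch >= 2.0 and is not something the original code relies on), the same core attention step, with dropout disabled, can be reproduced with torch.nn.functional.scaled_dot_product_attention:

# Illustrative cross-check on random per-head tensors, assuming PyTorch >= 2.0.
import torch.nn.functional as F
q = torch.rand(2, 4, 5, 192)  # [B, num_heads, N, head_size]
k = torch.rand(2, 4, 5, 192)
v = torch.rand(2, 4, 5, 192)
scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(q.size(-1))
manual = torch.matmul(torch.softmax(scores, dim=-1), v)
fused = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
print(torch.allclose(manual, fused, atol=1e-5))  # expected: True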
attention = SelfAttention(num_attention_heads=4, input_size=768, hidden_size=768,
                          hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1)
x = torch.rand(4, 5, 768)
output = attention(x)
print(output.shape)
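Because input_size and hidden_size are both 768 here, the output keeps the input shape torch.Size([4, 5, 768]). Note that the residual connection hidden_states + x in forward requires hidden_size to match the last dimension of the input, so a configuration with input_size != hidden_size would fail at that step.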