import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
1. token embedding
class TokenEmbedding(nn.Module):
    def __init__(self, vocab, d_model):
        super().__init__()
        self.emb = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, x):
        # Scale by sqrt(d_model), as in the original Transformer paper.
        return self.emb(x) * math.sqrt(self.d_model)
x = torch.LongTensor([[1, 2, 3], [4, 5, 6]])
emb = TokenEmbedding(vocab=10000, d_model=8)
res = emb(x)
res.shape
'''
torch.Size([2, 3, 8])
'''
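The sqrt(d_model) scaling above follows the original paper, where the token embeddings are then summed with positional encodings. That step is not part of this section; a minimal sketch of the standard sinusoidal formulation (names here are illustrative) could look like this:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        # Add the fixed positional encodings for the first `length` positions.
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
pos = PositionalEncoding(d_model=8)
pos(res).shape
'''
torch.Size([2, 3, 8])
'''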
2. attention
def attention(query, key, value, mask=None, dropout=None):
    """
    The last dimension of q must match the last dimension of k (this is d_k),
    and the first dimension of k must match the first dimension of v.
    E.g. q: 1 x 2, k: 3 x 2, v: 3 x 6
         q @ k^T       -> 1 x 3
         (q @ k^T) @ v -> 1 x 6
    Note: in practice q/k/v can be (bs, head_num, length, d_k).
    """
    d_k = query.shape[-1]
    key_t = key.transpose(-2, -1)
    score = torch.matmul(query, key_t) / math.sqrt(d_k)
    if mask is not None:
        # Positions where mask == 0 are blocked from receiving attention.
        score = score.masked_fill(mask == 0, -1e9)
    score = F.softmax(score, dim=-1)
    if dropout is not None:
        score = dropout(score)
    att_score = torch.matmul(score, value)
    return att_score
q = torch.rand(1, 2)
k = torch.rand(3, 2)
v = torch.rand(3, 6)
attention(q, k, v).shape
'''
torch.Size([1, 6])
'''
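The same function also handles batched multi-head tensors, as the docstring notes. The sketch below assumes a padding mask of shape (bs, 1, 1, length), which is one common convention for broadcasting over heads and query positions:
q = torch.rand(2, 4, 5, 8)   # (bs, head_num, length, d_k)
k = torch.rand(2, 4, 5, 8)
v = torch.rand(2, 4, 5, 8)
mask = torch.ones(2, 1, 1, 5)
mask[:, :, :, -1] = 0        # pretend the last position is padding
attention(q, k, v, mask=mask).shape
'''
torch.Size([2, 4, 5, 8])
'''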
3. MHSA
class MHSA(nn.Module):
    def __init__(self, head_num, embedding_dim):
        super().__init__()
        assert embedding_dim % head_num == 0
        self.d_k = embedding_dim // head_num
        self.head_num = head_num
        self.linear_q = nn.Linear(embedding_dim, embedding_dim)
        self.linear_k = nn.Linear(embedding_dim, embedding_dim)
        self.linear_v = nn.Linear(embedding_dim, embedding_dim)
        self.linear_final = nn.Linear(embedding_dim, embedding_dim)

    def forward(self, x):
        bs, length, embedding_dim = x.shape
        # Project, then split into heads: (bs, length, emb) -> (bs, length, head_num, d_k).
        q = self.linear_q(x).reshape(bs, length, self.head_num, self.d_k)
        k = self.linear_k(x).reshape(bs, length, self.head_num, self.d_k)
        v = self.linear_v(x).reshape(bs, length, self.head_num, self.d_k)
        # Move the head dimension forward: (bs, head_num, length, d_k).
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        att_score = attention(q, k, v)
        # Merge the heads back: (bs, length, head_num * d_k) = (bs, length, emb).
        att_score = att_score.transpose(1, 2).reshape(bs, length, -1)
        att_score = self.linear_final(att_score)
        return att_score
x = torch.rand(1, 4, 12)
mhsa = MHSA(6, 12)
mhsa(x).shape
'''
torch.Size([1, 4, 12])
'''
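A quick sanity check on the projections: there are four emb_dim x emb_dim linears (q, k, v, final), each with a bias, so with embedding_dim=12 that is 4 * (12*12 + 12) parameters:
sum(p.numel() for p in mhsa.parameters())
'''
624
'''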
4. Add & Norm (residual connection)
class SublayerConnection(nn.Module):
    def __init__(self, emb_dim, dropout=0.1):
        super().__init__()
        self.layernorm = nn.LayerNorm(emb_dim)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, sublayer_fn):
        '''
        :param x: input to the sublayer (not yet normalized)
        :param sublayer_fn: the sublayer to wrap (e.g. MHSA or FeedForward)
        :return: x + dropout(sublayer_fn(layernorm(x)))
        '''
        # Pre-LN residual: normalize, apply the sublayer, then add the *original* x.
        atten_score = sublayer_fn(self.layernorm(x))
        atten_score = self.dropout(atten_score)
        return x + atten_score
x = torch.rand(1, 4, 12)
mhsa = MHSA(3, 12)
sublayer_fn = lambda x: mhsa(x)
sc = SublayerConnection(12)
sc(x, sublayer_fn).shape
'''
torch.Size([1, 4, 12])
'''
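This is the pre-LN variant (LayerNorm before the sublayer, residual around it). The original paper uses post-LN; with the same class members, that forward would look roughly like:
def forward(self, x, sublayer_fn):
    # Post-LN (as in "Attention Is All You Need"): add the residual first, then normalize.
    return self.layernorm(x + self.dropout(sublayer_fn(x)))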
5. FeedForward
class FeedForward(nn.Module):
    def __init__(self, emb_dim, d_mid, dropout=0.1):
        super().__init__()
        self.w1 = nn.Linear(emb_dim, d_mid)
        self.w2 = nn.Linear(d_mid, emb_dim)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        x1 = self.w1(x)
        x1 = F.relu(x1)
        x1 = self.dropout(x1)
        x2 = self.w2(x1)
        return x2
x = torch.rand(1, 4, 12)
ff = FeedForward(12, 32)
ff(x).shape
'''
torch.Size([1, 4, 12])
'''
6. EncoderLayer
class EncoderLayer(nn.Module):
    def __init__(self, mhsa, feedforward, emb_dim):
        super().__init__()
        self.mhsa = mhsa
        self.feedforward = feedforward
        self.sublayerconnection_1 = SublayerConnection(emb_dim)
        self.sublayerconnection_2 = SublayerConnection(emb_dim)

    def forward(self, x):
        sublayer_fn = lambda x: self.mhsa(x)
        x = self.sublayerconnection_1(x, sublayer_fn)
        sublayer_fn = lambda x: self.feedforward(x)
        x = self.sublayerconnection_2(x, sublayer_fn)
        return x
x = torch.rand(1, 4, 12)
mhsa = MHSA(head_num=3, embedding_dim=12)
ff = FeedForward(emb_dim=12, d_mid=32)
encoderlayer = EncoderLayer(mhsa, ff, 12)
encoderlayer(x).shape
'''
torch.Size([1, 4, 12])
'''
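To complete the encoder, this layer is normally stacked N times with independent parameters, followed by a final LayerNorm (the usual pairing with the pre-LN residual used here). A minimal sketch:
class Encoder(nn.Module):
    def __init__(self, head_num, emb_dim, d_mid, N=6):
        super().__init__()
        # One fresh MHSA / FeedForward per layer (no weight sharing across layers).
        self.layers = nn.ModuleList([
            EncoderLayer(MHSA(head_num, emb_dim), FeedForward(emb_dim, d_mid), emb_dim)
            for _ in range(N)
        ])
        self.norm = nn.LayerNorm(emb_dim)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return self.norm(x)
x = torch.rand(1, 4, 12)
encoder = Encoder(head_num=3, emb_dim=12, d_mid=32, N=2)
encoder(x).shape
'''
torch.Size([1, 4, 12])
'''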