[PyTorch] Transformer encoder implementation

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np

1. token embedding

# token embedding
class TokenEmbedding(nn.Module):
    def __init__(self, vocab, d_model):
        super().__init__()
        # vocab: vocabulary size
        # d_model: embedding dimension of each token
        self.emb = nn.Embedding(vocab, d_model)
        self.d_model = d_model
    
    def forward(self, x):
        # scale by sqrt(d_model) so the embedding magnitudes are comparable to the positional encodings
        return self.emb(x) * math.sqrt(self.d_model)

x = torch.LongTensor([[1, 2, 3]
                      ,[4, 5, 6]])
# a batch of 2 sequences, each with 3 tokens along the time axis; each token id is embedded
emb = TokenEmbedding(vocab=10000, d_model=8)
res = emb(x)
res.shape
'''
    torch.Size([2, 3, 8])
'''
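
The sqrt(d_model) scaling above assumes positional encodings will be added to the embeddings. A minimal sketch of the standard sinusoidal PositionalEncoding (the class name, dropout and max_len defaults here are illustrative assumptions):

# sinusoidal positional encoding (illustrative sketch)
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # register as a buffer: moves with the module, but is not a trainable parameter
        self.register_buffer('pe', pe.unsqueeze(0))  # 1 * max_len * d_model

    def forward(self, x):
        # x: bs * length * d_model
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

pos_enc = PositionalEncoding(d_model=8)
pos_enc(res).shape
'''
    torch.Size([2, 3, 8])
'''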

2. attention

# attention
def attention(query, key, value, mask=None, dropout=None):
    """  
    q 的最后维度 和k的最后维度相同 >>> d_k
    k 的第一个维度 和v的第一个维度相同
    如:q 1*2; k 3 * 2; v 3 * 6
    q * k_t >>> 1 * 3
    q * k_t * v >>> 1 * 6
    注:实际qkv可以是 bs * head_num * length * embedding_dim
    """
    d_k = query.shape[-1]

    # transpose the last two dimensions of key
    key_t = key.transpose(-2, -1)
    
    score = torch.matmul(query, key_t) / math.sqrt(d_k)

    # mask out disallowed positions (e.g. padding) before the softmax
    if mask is not None:
        score = score.masked_fill(mask == 0, -1e9)

    score = F.softmax(score, dim=-1)

    # optionally apply dropout to the attention weights
    if dropout is not None:
        score = dropout(score)

    att_score = torch.matmul(score, value)

    return att_score

# test
q = torch.rand(1, 2)
k = torch.rand(3, 2)
v = torch.rand(3, 6)
attention(q, k, v).shape

'''
    torch.Size([1, 6])
'''
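
The same function also works on batched multi-head inputs (bs x head_num x length x d_k), and the mask argument can block attention to padded positions. A small sketch with illustrative shapes:

# batched multi-head input: bs=2, head_num=4, length=5, d_k=3
q = torch.rand(2, 4, 5, 3)
k = torch.rand(2, 4, 5, 3)
v = torch.rand(2, 4, 5, 3)
# mask the last two positions of every sequence (1 = keep, 0 = mask)
mask = torch.ones(2, 1, 1, 5)
mask[..., 3:] = 0
attention(q, k, v, mask=mask).shape
'''
    torch.Size([2, 4, 5, 3])
'''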

3. MHSA

class MHSA(nn.Module):
    def __init__(self, head_num, embedding_dim):
        super().__init__()

        assert embedding_dim % head_num == 0

        # d_k: per-head embedding dimension after splitting
        self.d_k = embedding_dim // head_num
        self.head_num = head_num

        self.linear_q = nn.Linear(embedding_dim, embedding_dim)
        self.linear_k = nn.Linear(embedding_dim, embedding_dim)
        self.linear_v = nn.Linear(embedding_dim, embedding_dim)

        # final linear projection applied after concatenating the heads
        self.linear_final = nn.Linear(embedding_dim, embedding_dim)
    
    def forward(self, x):
        # x is expected to be layer-normalized before it reaches this module
        # (alternatively, normalize here:
        # self.ln = nn.LayerNorm(embedding_dim)
        # x = self.ln(x))
        
        bs, length, embedding_dim = x.shape

        # project to q/k/v and split into heads
        q = self.linear_q(x).reshape(bs, length, self.head_num, self.d_k)
        k = self.linear_k(x).reshape(bs, length, self.head_num, self.d_k)
        v = self.linear_v(x).reshape(bs, length, self.head_num, self.d_k)

        # bs * head_num * length * d_k (per-head feature dim)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        att_score = attention(q, k, v)

        att_score = att_score.transpose(1, 2).reshape(bs, length, -1)

        att_score = self.linear_final(att_score)

        return att_score
    

# one batch with a sequence of 4 time steps, each with emb_dim=12
x = torch.rand(1, 4, 12)
# 6 heads, emb_dim=12
mhsa = MHSA(6, 12)

mhsa(x).shape
'''
    torch.Size([1, 4, 12])
'''
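
For a shape sanity check, PyTorch's built-in nn.MultiheadAttention produces the same output shape (the values differ, since its projection weights are initialized independently of our MHSA):

# built-in reference implementation (shape check only)
mha = nn.MultiheadAttention(embed_dim=12, num_heads=6, batch_first=True)
out, attn_weights = mha(x, x, x)
out.shape
'''
    torch.Size([1, 4, 12])
'''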

4. Add & Norm (residual connection)

# Add & Norm
class SublayerConnection(nn.Module):
    def __init__(self, emb_dim, dropout=0.1):
        super().__init__()

        self.layernorm = nn.LayerNorm(emb_dim)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, sublayer_fn):
        '''  
        :param x: sublayer input (not yet normalized)
        :param sublayer_fn: the sublayer (e.g. MHSA or FeedForward) applied to the normalized input
        :return: x + dropout(sublayer_fn(layernorm(x)))
        '''

        # pre-norm: normalize, apply the sublayer, then add the original x back as the residual
        atten_score = sublayer_fn(self.layernorm(x))

        atten_score = self.dropout(atten_score)

        return x + atten_score


x = torch.rand(1, 4, 12)
mhsa = MHSA(3, 12) 
sublayer_fn = lambda x: mhsa(x)
sc = SublayerConnection(12)
sc(x, sublayer_fn).shape
'''
    torch.Size([1, 4, 12])
'''
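
A quick sanity check of the skip connection: with a sublayer that outputs all zeros, the block reduces to the identity, so the input passes through unchanged along the residual path:

# zero sublayer -> output equals the input exactly
torch.allclose(sc(x, lambda t: torch.zeros_like(t)), x)
'''
    True
'''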

5. FeedForward

# Feed Forward: position-wise fully connected layer
# attention alone may not fit complex functions well enough; this adds extra modeling capacity
class FeedForward(nn.Module):
    def __init__(self, emb_dim, d_mid, dropout=0.1):
        super().__init__()

        self.w1 = nn.Linear(emb_dim, d_mid)
        self.w2 = nn.Linear(d_mid, emb_dim)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        x1 = self.w1(x)
        x1 = F.relu(x1)
        x1 = self.dropout(x1)

        x2 = self.w2(x1)

        return x2
    
x = torch.rand(1, 4, 12)
ff = FeedForward(12, 32)
ff(x).shape
'''
    torch.Size([1, 4, 12])
'''
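
The feed-forward layer is applied position-wise: the same weights act on every time step independently, so (with dropout disabled) feeding a single position gives the same result as slicing that position out of the full output:

# position-wise check: run in eval mode so dropout is disabled
ff.eval()
torch.allclose(ff(x)[:, 0], ff(x[:, 0:1]).squeeze(1))
'''
    True
'''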

6. EncoderLayer

# a single encoder layer
class EncoderLayer(nn.Module):
    def __init__(self, mhsa, feedforward, emb_dim):
        super().__init__()

        self.mhsa = mhsa
        self.feedforward = feedforward
        self.sublayerconnection_1 = SublayerConnection(emb_dim)
        self.sublayerconnection_2 = SublayerConnection(emb_dim)

    def forward(self, x):

        # residual connection around MHSA
        sublayer_fn = lambda x: self.mhsa(x)
        x = self.sublayerconnection_1(x, sublayer_fn)

        # residual connection around the feed-forward layer
        sublayer_fn = lambda x: self.feedforward(x)
        x = self.sublayerconnection_2(x, sublayer_fn)

        return x


x = torch.rand(1, 4, 12)
mhsa = MHSA(head_num=3, embedding_dim=12)
ff = FeedForward(emb_dim=12, d_mid=32)

encoderlayer = EncoderLayer(mhsa, ff, 12)
encoderlayer(x).shape
'''
    torch.Size([1, 4, 12])
'''
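
Finally, a complete encoder stacks N such layers (each with its own weights) on top of the token embedding and positional encoding, with a final LayerNorm. A minimal sketch, reusing the classes above plus the PositionalEncoding sketch from section 1; the Encoder class itself and its hyperparameters are illustrative assumptions:

import copy

# stack of N encoder layers + final layer norm (sketch)
class Encoder(nn.Module):
    def __init__(self, vocab, d_model, head_num, d_mid, N=6):
        super().__init__()
        self.token_emb = TokenEmbedding(vocab, d_model)
        self.pos_enc = PositionalEncoding(d_model)
        # deepcopy so every layer gets its own parameters
        layer = EncoderLayer(MHSA(head_num, d_model), FeedForward(d_model, d_mid), d_model)
        self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(N)])
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        x = self.pos_enc(self.token_emb(x))
        for layer in self.layers:
            x = layer(x)
        # final norm, as is common with pre-norm residual blocks
        return self.norm(x)

tokens = torch.LongTensor([[1, 2, 3, 4]])
encoder = Encoder(vocab=10000, d_model=12, head_num=6, d_mid=32, N=2)
encoder(tokens).shape
'''
    torch.Size([1, 4, 12])
'''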
