简化的 `Transformer` 编码器(Encoder)和解码器(Decoder)的 PyTorch 代码示例

Transformer 是由 Google 提出的用于自然语言处理(NLP)任务的一个深度学习模型架构,它基于自注意力(self-attention)机制。以下是一个简化的 Transformer 编码器(Encoder)和解码器(Decoder)的 PyTorch 代码示例。请注意,这个示例仅用于教学目的,并未包含完整的 Transformer 架构(如位置编码、层归一化、残差连接等)。

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    # 这是一个简化的多头注意力机制实现,用于演示
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        
        assert d_model % self.num_heads == 0
        
        self.depth = d_model // self.num_heads
        
        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)
        
        self.dense = nn.Linear(d_model, d_model)
        
    def split_heads(self, x, batch_size):
        x = x.reshape(batch_size, -1, self.num_heads, self.depth)
        return x.permute(0, 2, 1, 3)
    
    def forward(self, v, k, q, mask):
        batch_size = q.shape[0]
        
        q = self.wq(q)  # (batch_size, seq_len, d_model)
        k = self.wk(k)  # (batch_size, seq_len, d_model)
        v = self.wv(v)  # (batch_size, seq_len, d_model)
        
        q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
        k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
        v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)
        
        # Scaled Dot-Product Attention
        scaled_attention_logits = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.depth, dtype=torch.float32))
        
        if mask is not None:
            scaled_attention_logits += (mask * -1e9)  # Add the mask to the scaled tensor.
            
        attention_weights = nn.Softmax(dim=-1)(scaled_attention_logits)  # (batch_size, num_heads, seq_len_q, seq_len_k)
        
        output = torch.matmul(attention_weights, v)  # (batch_size, num_heads, seq_len_q, depth)
        output = output.permute(0, 2, 1, 3).contiguous()  # (batch_size, seq_len_q, num_heads, depth)
        output = output.reshape(batch_size, -1, self.d_model)  # (batch_size, seq_len_q, d_model)
        
        output = self.dense(output)  # (batch_size, seq_len_q, d_model)
        
        return output, attention_weights

# 这里省略了 Transformer 的其他组件,如前馈神经网络、位置编码、层归一化等。

# 如果你想要完整的 Transformer 编码器或解码器,你还需要实现这些组件,并把它们组合在一起。

注意:上面的 MultiHeadAttention 类仅实现了多头注意力机制的核心部分,并没有包含完整的 Transformer 编码器或解码器。一个完整的 Transformer 编码器通常包括一个多头注意力层、一个前馈神经网络(FFN)以及可能的层归一化和残差连接。解码器则通常包括两个多头注意力层(一个自注意力层和一个编码器-解码器注意力层)以及一个前馈神经网络。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Dalao_zzl

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值