[自然语言处理nlp(一)] Transformer运算过程与代码解析

本文深入解析Transformer模型的运算过程,包括Encoder、Decoder、Attention和Feed-Forward Networks。通过分析论文中的结构图,阐述了每个模块的功能和代码实现,特别是MultiHeadAttention的计算细节。此外,介绍了模型的整体结构,强调了Decoder中特有的MaskedMultiHeadAttention以及其与Encoder的区别。
摘要由CSDN通过智能技术生成

近期开始学习NLP,边学习边整理。

本文代码主要来自地址1,结合地址2里描述的tensor大小进行理解,从论文图片开始,一步一步按照图片解析每个模块的运算过程。
在这里插入图片描述

整体结构

上图的结构可以分为编码器与解码器,可以形成以下代码。图中张量大小如下:
Inputs [batch, 最大的句长]。此时,Inputs[0, 1]表示批数据里第一句子的第二个单词。
经过Input Embedding与Positional Encoding [batch, 最大的句长, 词向量长度]。这里将每个单词映射为了词向量,因此多了一个维度。

class EncoderDecoder(nn.Module):
    """
    A standard Encoder-Decoder architecture. Base for this and many 
    other models.
    """
    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator
        
    def forward(self, src, tgt, src_mask, tgt_mask):
        "Take in and process masked src and target sequences."
        return self.decode(self.encode(src, src_mask), src_mask,
                            tgt, tgt_mask)
    
    def encode(self, src, src_mask):
        return self.encoder(self.src_embed(src), src_mask)
    
    def decode(self, memory, src_mask, tgt, tgt_mask):
        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)

class Generator(nn.Module):
    "Define standard linear + softmax generation step."
    def __init__(self, d_model, vocab):
        super(Generator, self).__init__()
        self.proj = nn.Linear(d_model, vocab)

    def forward(self, x):
        return F.log_softmax(self.proj(x), dim=-1)

Encoder

在这里插入图片描述
解码器Encoder如上图所示,由N=6个相同的层组成

def clones(module, N):
    "Produce N identical layers."
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

class Encoder(nn.Module
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值