从头实现一个完整的Transformer模型

引言

        当我决定深入研究Transformer架构时,我常常在阅读或观看网上教程时感到挫败,因为总觉得它们缺少一些关键内容:

  • Tensorflow或Pytorch的官方教程使用了自己的API,保持在高层次,这迫使我不得不深入它们的代码库去了解底层实现。这非常耗时,而且阅读成千上万行代码也并不容易。
  • 其他使用自定义代码的教程(文章末尾有链接)通常对用例过于简化,没有涉及诸如可变长度序列批处理的掩码处理等概念。

因此,我决定自己编写一个Transformer,以确保我理解这些概念,并能够将其应用于任何数据集。在本文中,我们将采用一种系统的方法,逐层、逐块地实现一个Transformer。

        那么,为什么不使用TF/Pytorch的实现呢?本文的目的是教育性的,我并不打算超越Pytorch或Tensorflow的实现。我确实认为Transformer的理论和背后的代码并不容易理解,这也是我希望通过这个一步步的教程,让你更好地掌握这些概念,并在以后编写自己的代码时感到更自在的原因。从头开始构建自己的Transformer的另一个原因是,它将让你完全理解如何使用上述API。如果我们查看Pytorch中Transformer类的forward()方法实现,你会看到很多晦涩的关键词,比如:

        如果你已经熟悉这些关键词,那么你可以愉快地跳过这篇文章。否则,本文将带你逐一理解这些关键词及其背后的概念。

Transformer简短介绍

        如果你听说过ChatGPT或Gemini,那么你已经遇到过Transformer了。事实上,ChatGPT中的“T”代表Transformer。该架构首次提出是在2017年,由谷歌研究人员在论文《Attention is All You Need》中提出。这是一种非常革命性的架构,因为以前的模型(用于序列到序列学习的机器翻译、语音转文本等)依赖于计算上昂贵的RNN,它们必须逐步处理序列。而Transformer只需要一次性地查看整个序列,将时间复杂度(Sequential Operations)从 O(n) 降低到 O(1)。

        对于其他的复杂度指标,我们先在此做最简单的理解,相信你在看完本篇文章后,将会有更深刻的认识。Complexity per Layer代表每层(主要)的计算复杂度,对于自注意力(Self-Attention),假设输入序列的长度为n,特征维度为d,首先会生成3个(Q, K, V)长度为n,特征维度为d的嵌入矩阵,Q与K(的转置)做矩阵乘法得到n x n维度的注意力图,代表每个位置与其他位置的注意力权重,该步骤复杂度为O(n^2 d),然后该注意力图与V做矩阵乘法得到更新后的n*d维度的输出序列特征表示,当然还有后面的线性变换层等操作,但主要的复杂度来自注意力图的计算。Sequential Operations代表每层需要依次进行的操作数,顺序操作数为 O(1),意味着这些操作可以并行化;顺序操作数为 O(n),代表需要依次进行n次操作。Maximum Path Length表示信息在层中传播所需经过的最大路径长度,Self-Attention只需操作一次就可以全局更新每个位置上的特征,而卷积需要log_k (n)次操作(受限于卷积窗口的大小)才可以更新每个位置,尽管这些卷积操作可以并行计算。

        通过这个表格可以看出,不同类型的层在计算复杂度、并行化能力和信息传播速度上的区别。自注意力层(Self-Attention)在并行化和信息传播速度上有明显优势,但计算复杂度较高。而循环层(Recurrent)虽然计算复杂度较低,但顺序操作数和最大路径长度较长,限制了并行化能力和信息传播速度。卷积层(Convolutional)在这三者之间,综合了计算复杂度和并行化能力。

Transformer的总体结构如下:

        如果你每次看到这个图都不知所云,我可以先告诉你几点:

        第一:在很多时候,包括你看到的很多文章里的模型,只用到了上图的编码器部分,用于提取输入序列的特征(如Vision Transformer, VIT)。而解码器主要用于构建完整的生成模型,用于像语言翻译、文字续写之类的生成任务。如果你的目的是分类这样的判别式任务,你大概率不会接触到Transformer解码器。

        第二:假设你的任务是一个文字翻译任务,当你在训练网络的时候,你的input会是一个经编码后的序列,假设没编码前的序列是“我很帅”,那么,你的解码器的输入会是GT(“I am handsome”),然后一次性预测GT中每个位置的词,当然这里会用到Masked Attention来保证解码器在预测“I”时不会看到“ am handsome”。这在训练时是并行执行的,与推理时的自回归方式不同。

        你需要明白,当我们说Transformer的时候,就是在说上面这个结构,他包含一个编码器和一个解码器,而不仅仅是编码器,更不是在说注意力机制。这点很重要,特别是当你来自计算机视觉领域,而不是自然语言处理。(手动狗头)

        好,现在让我们进入正题。

多头自注意力

        我们将实现的第一个块实际上是Transformer中最重要的部分,它被称为多头注意。让我们看看它在整个体系结构中的位置

        注意是一种机制,实际上并不是Transformer所特有的,它已经在RNN sequence-to-sequence模型中使用。

import torch
import torch.nn as nn
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, hidden_dim=256, num_heads=4):
        """
        input_dim: Dimensionality of the input.
        num_heads: The number of attention heads to split the input into.
        """
        super(MultiHeadAttention, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        assert hidden_dim % num_heads == 0, "Hidden dim must be divisible by num heads"
        self.Wv = nn.Linear(hidden_dim, hidden_dim, bias=False) # the Value part
        self.Wk = nn.Linear(hidden_dim, hidden_dim, bias=False) # the Key part
        self.Wq = nn.Linear(hidden_dim, hidden_dim, bias=False) # the Query part
        self.Wo = nn.Linear(hidden_dim, hidden_dim, bias=False) # the output layer
        
        
    def check_sdpa_inputs(self, x):
        assert x.size(1) == self.num_heads, f"Expected size of x to be ({-1, self.num_heads, -1, self.hidden_dim // self.num_heads}), got {x.size()}"
        assert x.size(3) == self.hidden_dim // self.num_heads
        
        
    def scaled_dot_product_attention(self, query, key, value, 
            attention_mask=None, key_padding_mask=None):
        """
        query : (batch_size, num_heads, query_sequence_length, hidden_dim//num_heads)
        key : (batch_size, num_heads, key_sequence_length, hidden_dim//num_heads)
        value :  (batch_size, num_heads, key_sequence_length, hidden_dim//num_heads)
        attention_mask : (query_sequence_length, key_sequence_length)
        key_padding_mask : (sequence_length, key_sequence_length)
    
        """
        self.check_sdpa_inputs(query)
        self.check_sdpa_inputs(key)
        self.check_sdpa_inputs(value)
        
        
        d_k = query.size(-1)
        tgt_len, src_len = query.size(-2), key.size(-2)

        
        # logits = (B, H, tgt_len, E) * (B, H, E, src_len) = (B, H, tgt_len, src_len)
        logits = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) 
        
        # Attention mask here
        if attention_mask is not None:
            if attention_mask.dim() == 2:
                assert attention_mask.size() == (tgt_len, src_len)
                attention_mask = attention_mask.unsqueeze(0)
                logits = logits + attention_mask
            else:
                raise ValueError(f"Attention mask size {attention_mask.size()}")
        
                
        # Key mask here
        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask.unsqueeze(1).unsqueeze(2) # Broadcast over batch size, num heads
            logits = logits + key_padding_mask
        
        
        attention = torch.softmax(logits, dim=-1)
        output = torch.matmul(attention, value) # (batch_size, num_heads, sequence_length, hidden_dim)
        
        return output, attention

    
    def split_into_heads(self, x, num_heads):
        batch_size, seq_length, hidden_dim = x.size()
        x = x.view(batch_size, seq_length, num_heads, hidden_dim // num_heads)
        
        return x.transpose(1, 2) # Final dim will be (batch_size, num_heads, seq_length, , hidden_dim // num_heads)

    def combine_heads(self, x):
        batch_size, num_heads, seq_length, head_hidden_dim = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, num_heads * head_hidden_dim)
        
    
    def forward(self, q, k, v, attention_mask=None, key_padding_mask=None):
        """
        q :  (batch_size, query_sequence_length, hidden_dim)
        k :  (batch_size, key_sequence_length, hidden_dim)
        v :  (batch_size, key_sequence_length, hidden_dim)
        attention_mask :  (query_sequence_length, key_sequence_length)
        key_padding_mask :  (sequence_length, key_sequence_length)
       
        """
        q = self.Wq(q)
        k = self.Wk(k)
        v = self.Wv(v)

        q = self.split_into_heads(q, self.num_heads)
        k = self.split_into_heads(k, self.num_heads)
        v = self.split_into_heads(v, self.num_heads)
        
        # attn_values, attn_weights = self.multihead_attn(q, k, v, attn_mask=attention_mask)
        attn_values, attn_weights  = self.scaled_dot_product_attention(
            query=q, 
            key=k, 
            value=v, 
            attention_mask=attention_mask,
            key_padding_mask=key_padding_mask,
        )
        grouped = self.combine_heads(attn_values)
        output = self.Wo(grouped)
        
        self.attention_weigths = attn_weights
        
        return output

在这里,我们需要解释几个概念。 

1) Queries, Keys and Values.

查询(Queries)是你试图匹配的信息,键(Keys)和值(Values)是存储的信息。就像使用字典一样:无论何时使用Python字典,如果您的查询与字典键不匹配,则不会返回任何内容。但是,如果我们希望字典返回的信息非常接近,该怎么办呢?如果我们有:

d = {"panther": 1, "bear": 10, "dog":3}
d["wolf"] = 0.2*d["panther"] + 0.7*d["dog"] + 0.1*d["bear"]

这基本上就是注意力的意义所在:查看数据的不同部分,并将它们混合起来,以获得一个综合的答案。

代码的相关部分是这个,我们计算查询和键之间的关注权重,得到注意力图,然后将注意力图与值相乘,得到最终融合的结果。

logits = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) # 计算注意力图
attention = torch.softmax(logits, dim=-1)
output = torch.matmul(attention, value) # (batch_size, num_heads, sequence_length, hidden_dim)

2)注意力掩码和填充

        在对序列输入的部分进行注意力操作时,我们不希望包含无用或禁止的信息。无用的信息例如填充符号:填充符号用于将批处理中所有序列对齐到相同的序列长度,这些信息应被模型忽略。我们将在最后一节详细讨论这一点。禁止的信息稍微复杂一些。在训练时,模型学习编码输入序列,并将目标对齐到输入。然而,推理过程中涉及查看之前生成的标记以预测下一个标记(想想ChatGPT中的文本生成),我们需要在训练过程中应用相同的规则。

        这就是为什么我们应用因果掩码(causal mask),以确保在每个时间步,目标只能看到来自过去的信息。下面是应用掩码的对应部分(计算掩码的过程将在最后介绍)。

if attention_mask is not None:
    if attention_mask.dim() == 2:
        assert attention_mask.size() == (tgt_len, src_len)
        attention_mask = attention_mask.unsqueeze(0)
        logits = logits + attention_mask

位置编码

这对应于Transformer中的以下部分:

        当接收和处理输入时,Transformer对顺序没有概念,因为它将序列视为一个整体,这与RNN的处理方式不同。因此,我们需要添加一些时间顺序的信息,使Transformer能够学习到依赖关系。位置编码的具体工作细节不在本文的讨论范围内,但你可以阅读原始论文来了解更多内容。

# Taken from https://pytorch.org/tutorials/beginner/transformer_tutorial.html#define-the-model
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Arguments:
            x: Tensor, shape ``[batch_size, seq_len, embedding_dim]``
        """
        x = x + self.pe[:, :x.size(1), :]
        return x

编码器

我们即将完成一个完整的编码器!编码器是Transformer的左半部分。

我们将在代码中添加一小部分,这是Feed Forward部分:

class PositionWiseFeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int):
        super(PositionWiseFeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

把所有这些组合在一起,我们得到了一个encoder 模块!

class EncoderBlock(nn.Module):
    def __init__(self, n_dim: int, dropout: float, n_heads: int):
        super(EncoderBlock, self).__init__()
        self.mha = MultiHeadAttention(hidden_dim=n_dim, num_heads=n_heads)
        self.norm1 = nn.LayerNorm(n_dim)
        self.ff = PositionWiseFeedForward(n_dim, n_dim)
        self.norm2 = nn.LayerNorm(n_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, src_padding_mask=None):
        assert x.ndim==3, "Expected input to be 3-dim, got {}".format(x.ndim)
        att_output = self.mha(x, x, x, key_padding_mask=src_padding_mask)
        x = x + self.dropout(self.norm1(att_output))
        
        ff_output = self.ff(x)
        output = x + self.norm2(ff_output)
       
        return output

如图所示,编码器实际上包含N个编码器块或层,以及用于输入的嵌入层。因此,让我们通过添加嵌入、位置编码和编码器块来创建一个编码器:

class Encoder(nn.Module):
    def __init__(
            self, 
            vocab_size: int, 
            n_dim: int, 
            dropout: float, 
            n_encoder_blocks: int,
            n_heads: int):
        
        super(Encoder, self).__init__()
        self.n_dim = n_dim

        self.embedding = nn.Embedding(
            num_embeddings=vocab_size, 
            embedding_dim=n_dim
        )
        self.positional_encoding = PositionalEncoding(
            d_model=n_dim, 
            dropout=dropout
        )    
        self.encoder_blocks = nn.ModuleList([
            EncoderBlock(n_dim, dropout, n_heads) for _ in range(n_encoder_blocks)
        ])
        
        
    def forward(self, x, padding_mask=None):
        x = self.embedding(x) * math.sqrt(self.n_dim)
        x = self.positional_encoding(x)
        for block in self.encoder_blocks:
            x = block(x=x, src_padding_mask=padding_mask)
        return x

解码器

class DecoderBlock(nn.Module):
    def __init__(self, n_dim: int, dropout: float, n_heads: int):
        super(DecoderBlock, self).__init__()
        
        # The first Multi-Head Attention has a mask to avoid looking at the future
        self.self_attention = MultiHeadAttention(hidden_dim=n_dim, num_heads=n_heads)
        self.norm1 = nn.LayerNorm(n_dim)
        
        # The second Multi-Head Attention will take inputs from the encoder as key/value inputs
        self.cross_attention = MultiHeadAttention(hidden_dim=n_dim, num_heads=n_heads)
        self.norm2 = nn.LayerNorm(n_dim)
        
        self.ff = PositionWiseFeedForward(n_dim, n_dim)
        self.norm3 = nn.LayerNorm(n_dim)
        # self.dropout = nn.Dropout(dropout)
        
        
    def forward(self, tgt, memory, tgt_mask=None, tgt_padding_mask=None, memory_padding_mask=None):
        
        masked_att_output = self.self_attention(
            q=tgt, k=tgt, v=tgt, attention_mask=tgt_mask, key_padding_mask=tgt_padding_mask)
        x1 = tgt + self.norm1(masked_att_output)
        
        cross_att_output = self.cross_attention(
            q=x1, k=memory, v=memory, attention_mask=None, key_padding_mask=memory_padding_mask)
        x2 = x1 + self.norm2(cross_att_output)
        
        ff_output = self.ff(x2)
        output = x2 + self.norm3(ff_output)

        
        return output

class Decoder(nn.Module):
    def __init__(
        self, 
        vocab_size: int, 
        n_dim: int, 
        dropout: float, 
        n_decoder_blocks: int,
        n_heads: int):
        
        super(Decoder, self).__init__()
        
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size, 
            embedding_dim=n_dim,
            padding_idx=0
        )
        self.positional_encoding = PositionalEncoding(
            d_model=n_dim, 
            dropout=dropout
        )
          
        self.decoder_blocks = nn.ModuleList([
            DecoderBlock(n_dim, dropout, n_heads) for _ in range(n_decoder_blocks)
        ])
        
        
    def forward(self, tgt, memory, tgt_mask=None, tgt_padding_mask=None, memory_padding_mask=None):
        x = self.embedding(tgt)
        x = self.positional_encoding(x)

        for block in self.decoder_blocks:
            x = block(
                x, 
                memory, 
                tgt_mask=tgt_mask, 
                tgt_padding_mask=tgt_padding_mask, 
                memory_padding_mask=memory_padding_mask)
        return x

Padding & Masking

请记住在多头注意力部分我们提到的在进行注意力操作时排除输入的某些部分。

在训练过程中,我们会考虑输入和目标的批次,其中每个实例可能有不同的长度。考虑下面这个例子,我们批处理了4个单词:banana、watermelon、pear、blueberry。为了将它们作为一个批次处理,我们需要将所有单词对齐到最长单词(watermelon)的长度。因此,我们会在每个单词后添加一个额外的填充标记(PAD),使它们的长度都与watermelon相同。

在下图中,上表表示原始数据,下表表示编码后的版本:

在我们的例子中,我们希望从计算的注意力权重中排除填充索引。因此,我们可以计算一个掩码如下,这对于source(src)和target(tar)都一样:

padding_mask = (x == PAD_IDX)

那么,现在来说说因果掩码。如果我们希望模型在每个时间步只关注过去的步骤,这意味着对于每个时间步 T,模型只能关注每个 t(从1到 T)的步骤。这实际上是一个双重循环,因此我们可以使用矩阵来计算这一点:

def generate_square_subsequent_mask(size: int):
      """Generate a triangular (size, size) mask. From PyTorch docs."""
      mask = (1 - torch.triu(torch.ones(size, size), diagonal=1)).bool()
      mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
      return mask

案例研究:一个单词反转Transformer

现在,让我们通过将各部分结合在一起构建我们的Transformer!

在我们的用例中,我们将使用一个非常简单的数据集来展示Transformer是如何实际学习的。

“但是,为什么要用Transformer来反转单词?我已经知道如何用Python的 word[::-1] 来做到这一点了!”

这里的目标是看看Transformer的注意力机制是否有效。我们期望看到的是,当给定一个输入序列时,注意力权重从右向左移动。如果是这样,这意味着我们的Transformer学会了一种非常简单的语法,即从右到左读取,并且在进行实际的语言翻译时可以推广到更复杂的语法。

首先,让我们从自定义的Transformer类开始:

import torch
import torch.nn as nn
import math

from .encoder import Encoder
from .decoder import Decoder


class Transformer(nn.Module):
    def __init__(self, **kwargs):
        super(Transformer, self).__init__()
        
        for k, v in kwargs.items():
            print(f" * {k}={v}")
        
        self.vocab_size = kwargs.get('vocab_size')
        self.model_dim = kwargs.get('model_dim')
        self.dropout = kwargs.get('dropout')
        self.n_encoder_layers = kwargs.get('n_encoder_layers')
        self.n_decoder_layers = kwargs.get('n_decoder_layers')
        self.n_heads = kwargs.get('n_heads')
        self.batch_size = kwargs.get('batch_size')
        self.PAD_IDX = kwargs.get('pad_idx', 0)

        self.encoder = Encoder(
            self.vocab_size, self.model_dim, self.dropout, self.n_encoder_layers, self.n_heads)
        self.decoder = Decoder(
            self.vocab_size, self.model_dim, self.dropout, self.n_decoder_layers, self.n_heads)
        self.fc = nn.Linear(self.model_dim, self.vocab_size)
        

    @staticmethod    
    def generate_square_subsequent_mask(size: int):
            """Generate a triangular (size, size) mask. From PyTorch docs."""
            mask = (1 - torch.triu(torch.ones(size, size), diagonal=1)).bool()
            mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
            return mask


    def encode(
            self, 
            x: torch.Tensor, 
        ) -> torch.Tensor:
        """
        Input
            x: (B, S) with elements in (0, C) where C is num_classes
        Output
            (B, S, E) embedding
        """

        mask = (x == self.PAD_IDX).float()
        encoder_padding_mask = mask.masked_fill(mask == 1, float('-inf'))
        
        # (B, S, E)
        encoder_output = self.encoder(
            x, 
            padding_mask=encoder_padding_mask
        )  
        
        return encoder_output, encoder_padding_mask
    
    
    def decode(
            self, 
            tgt: torch.Tensor, 
            memory: torch.Tensor, 
            memory_padding_mask=None
        ) -> torch.Tensor:
        """
        B = Batch size
        S = Source sequence length
        L = Target sequence length
        E = Model dimension
        
        Input
            encoded_x: (B, S, E)
            y: (B, L) with elements in (0, C) where C is num_classes
        Output
            (B, L, C) logits
        """
        
        mask = (tgt == self.PAD_IDX).float()
        tgt_padding_mask = mask.masked_fill(mask == 1, float('-inf'))

        decoder_output = self.decoder(
            tgt=tgt, 
            memory=memory, 
            tgt_mask=self.generate_square_subsequent_mask(tgt.size(1)), 
            tgt_padding_mask=tgt_padding_mask, 
            memory_padding_mask=memory_padding_mask,
        )  
        output = self.fc(decoder_output)  # shape (B, L, C)
        return output

        
        
    def forward(
            self, 
            x: torch.Tensor, 
            y: torch.Tensor, 
        ) -> torch.Tensor:
        """
        Input
            x: (B, Sx) with elements in (0, C) where C is num_classes
            y: (B, Sy) with elements in (0, C) where C is num_classes
        Output
            (B, L, C) logits
        """
        
        # Encoder output shape (B, S, E)
        encoder_output, encoder_padding_mask = self.encode(x)  

        # Decoder output shape (B, L, C)
        decoder_output = self.decode(
            tgt=y, 
            memory=encoder_output, 
            memory_padding_mask=encoder_padding_mask
        )  
        
        return decoder_output

使用贪婪解码进行推理

        我们需要添加一个方法,使其类似于scikit-learn中的 model.predict。其目标是让模型在给定输入时动态输出预测结果。在推理过程中,没有目标值:模型通过关注输出来开始生成一个标记,并使用自己的预测来继续生成后续标记。这也是为什么这些模型常被称为自回归模型,因为它们使用过去的预测来预测下一个标记。

        贪婪解码的问题在于它在每一步都选择概率最高的标记。如果最初几个标记完全错误,这可能导致非常糟糕的预测。还有其他解码方法,例如Beam搜索,它会考虑一份候选序列的短列表(可以理解为在每个时间步保留前k个标记,而不是仅选择argmax),并返回总概率最高的序列。

        现在,让我们实现贪婪解码并将其添加到我们的Transformer模型中:

def predict(
            self,
            x: torch.Tensor,
            sos_idx: int=1,
            eos_idx: int=2,
            max_length: int=None
        ) -> torch.Tensor:
        """
        Method to use at inference time. Predict y from x one token at a time. This method is greedy
        decoding. Beam search can be used instead for a potential accuracy boost.

        Input
            x: str
        Output
            (B, L, C) logits
        """

        # Pad the tokens with beginning and end of sentence tokens
        x = torch.cat([
            torch.tensor([sos_idx]), 
            x, 
            torch.tensor([eos_idx])]
        ).unsqueeze(0)

        encoder_output, mask = self.transformer.encode(x) # (B, S, E)
        
        if not max_length:
            max_length = x.size(1)

        outputs = torch.ones((x.size()[0], max_length)).type_as(x).long() * sos_idx
        for step in range(1, max_length):
            y = outputs[:, :step]
            probs = self.transformer.decode(y, encoder_output)
            output = torch.argmax(probs, dim=-1)
            
            # Uncomment if you want to see step by step predicitons
            # print(f"Knowing {y} we output {output[:, -1]}")

            if output[:, -1].detach().numpy() in (eos_idx, sos_idx):
                break
            outputs[:, step] = output[:, -1]
            
        
        return outputs

创建玩具数据

我们定义一个小数据集,该数据集将单词反转,例如“helloworld”将返回“dlrowolleh”:

import numpy as np
import torch
from torch.utils.data import Dataset


np.random.seed(0)

def generate_random_string():
    len = np.random.randint(10, 20)
    return "".join([chr(x) for x in np.random.randint(97, 97+26, len)])

class ReverseDataset(Dataset):
    def __init__(self, n_samples, pad_idx, sos_idx, eos_idx):
        super(ReverseDataset, self).__init__()
        self.pad_idx = pad_idx
        self.sos_idx = sos_idx
        self.eos_idx = eos_idx
        self.values = [generate_random_string() for _ in range(n_samples)]
        self.labels = [x[::-1] for x in self.values]

    def __len__(self):
        return len(self.values)  # number of samples in the dataset

    def __getitem__(self, index):
        return self.text_transform(self.values[index].rstrip("\n")), \
            self.text_transform(self.labels[index].rstrip("\n"))
        
    def text_transform(self, x):
        return torch.tensor([self.sos_idx] + [ord(z)-97+3 for z in x] + [self.eos_idx])

现在我们将定义训练和评估步骤:

PAD_IDX = 0
SOS_IDX = 1
EOS_IDX = 2

def train(model, optimizer, loader, loss_fn, epoch):
    model.train()
    losses = 0
    acc = 0
    history_loss = []
    history_acc = [] 

    with tqdm(loader, position=0, leave=True) as tepoch:
        for x, y in tepoch:
            tepoch.set_description(f"Epoch {epoch}")

            optimizer.zero_grad()
            logits = model(x, y[:, :-1])
            loss = loss_fn(logits.contiguous().view(-1, model.vocab_size), y[:, 1:].contiguous().view(-1))
            loss.backward()
            optimizer.step()
            losses += loss.item()
            
            preds = logits.argmax(dim=-1)
            masked_pred = preds * (y[:, 1:]!=PAD_IDX)
            accuracy = (masked_pred == y[:, 1:]).float().mean()
            acc += accuracy.item()
            
            history_loss.append(loss.item())
            history_acc.append(accuracy.item())
            tepoch.set_postfix(loss=loss.item(), accuracy=100. * accuracy.item())

    return losses / len(list(loader)), acc / len(list(loader)), history_loss, history_acc


def evaluate(model, loader, loss_fn):
    model.eval()
    losses = 0
    acc = 0
    history_loss = []
    history_acc = [] 

    for x, y in tqdm(loader, position=0, leave=True):

        logits = model(x, y[:, :-1])
        loss = loss_fn(logits.contiguous().view(-1, model.vocab_size), y[:, 1:].contiguous().view(-1))
        losses += loss.item()
        
        preds = logits.argmax(dim=-1)
        masked_pred = preds * (y[:, 1:]!=PAD_IDX)
        accuracy = (masked_pred == y[:, 1:]).float().mean()
        acc += accuracy.item()
        
        history_loss.append(loss.item())
        history_acc.append(accuracy.item())

    return losses / len(list(loader)), acc / len(list(loader)), history_loss, history_acc

然后训练模型几个epoch:

import torch
import time
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from mpl_toolkits.axes_grid1 import ImageGrid


def collate_fn(batch):
    """ 
    This function pads inputs with PAD_IDX to have batches of equal length
    """
    src_batch, tgt_batch = [], []
    for src_sample, tgt_sample in batch:
        src_batch.append(src_sample)
        tgt_batch.append(tgt_sample)

    src_batch = pad_sequence(src_batch, padding_value=PAD_IDX, batch_first=True)
    tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX, batch_first=True)
    return src_batch, tgt_batch

# Model hyperparameters
args = {
    'vocab_size': 128,
    'model_dim': 128,
    'dropout': 0.1,
    'n_encoder_layers': 1,
    'n_decoder_layers': 1,
    'n_heads': 4
}

# Define model here
model = Transformer(**args)

# Instantiate datasets
train_iter = ReverseDataset(50000, pad_idx=PAD_IDX, sos_idx=SOS_IDX, eos_idx=EOS_IDX)
eval_iter = ReverseDataset(10000, pad_idx=PAD_IDX, sos_idx=SOS_IDX, eos_idx=EOS_IDX)
dataloader_train = DataLoader(train_iter, batch_size=256, collate_fn=collate_fn)
dataloader_val = DataLoader(eval_iter, batch_size=256, collate_fn=collate_fn)

# During debugging, we ensure sources and targets are indeed reversed
# s, t = next(iter(dataloader_train))
# print(s[:4, ...])
# print(t[:4, ...])
# print(s.size())

# Initialize model parameters
for p in model.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

# Define loss function : we ignore logits which are padding tokens
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.98), eps=1e-9)

# Save history to dictionnary
history = {
    'train_loss': [],
    'eval_loss': [],
    'train_acc': [],
    'eval_acc': []
}

# Main loop
for epoch in range(1, 4):
    start_time = time.time()
    train_loss, train_acc, hist_loss, hist_acc = train(model, optimizer, dataloader_train, loss_fn, epoch)
    history['train_loss'] += hist_loss
    history['train_acc'] += hist_acc
    end_time = time.time()
    val_loss, val_acc, hist_loss, hist_acc = evaluate(model, dataloader_val, loss_fn)
    history['eval_loss'] += hist_loss
    history['eval_acc'] += hist_acc
    print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Train acc: {train_acc:.3f}, Val loss: {val_loss:.3f}, Val acc: {val_acc:.3f} "f"Epoch time = {(end_time - start_time):.3f}s"))

可视化注意力

我们定义了一个小函数来访问注意头的权重:

fig = plt.figure(figsize=(10., 10.))
images = model.decoder.decoder_blocks[0].cross_attention.attention_weigths[0,...].detach().numpy()
grid = ImageGrid(fig, 111,  # similar to subplot(111)
                nrows_ncols=(2, 2),  # creates 2x2 grid of axes
                axes_pad=0.1,  # pad between axes in inch.
                )

for ax, im in zip(grid, images):
    # Iterating over the grid returns the Axes.
    ax.imshow(im)

当从顶部读取权重时,我们可以看到一个很好的从右到左模式。由于填充蒙版,y轴底部的垂直部分肯定可以表示蒙版权重

测试我们的模型!


为了用新数据测试我们的模型,我们将定义一个小小的Translator类来帮助我们解码模型的输出:

class Translator(nn.Module):
    def __init__(self, transformer):
        super(Translator, self).__init__()
        self.transformer = transformer
    
    @staticmethod
    def str_to_tokens(s):
        return [ord(z)-97+3 for z in s]
    
    @staticmethod
    def tokens_to_str(tokens):
        return "".join([chr(x+94) for x in tokens])
    
    def __call__(self, sentence, max_length=None, pad=False):
        
        x = torch.tensor(self.str_to_tokens(sentence))
        x = torch.cat([torch.tensor([SOS_IDX]), x, torch.tensor([EOS_IDX])]).unsqueeze(0)
        
        encoder_output, mask = self.transformer.encode(x) # (B, S, E)
        
        if not max_length:
            max_length = x.size(1)
            
        outputs = torch.ones((x.size()[0], max_length)).type_as(x).long() * SOS_IDX
        
        for step in range(1, max_length):
            y = outputs[:, :step]
            probs = self.transformer.decode(y, encoder_output)
            output = torch.argmax(probs, dim=-1)
            print(f"Knowing {y} we output {output[:, -1]}")
            if output[:, -1].detach().numpy() in (EOS_IDX, SOS_IDX):
                break
            outputs[:, step] = output[:, -1]
            
        
        return self.tokens_to_str(outputs[0])

translator = Translator(model)

你应该能够看到以下内容:

结论

        就是这样,你现在可以编写Transformer并将其用于更大的数据集来执行机器翻译,例如创建您自己的BERT ! 我希望本教程向你展示编写Transformer时的注意事项:填充和屏蔽可能是最需要注意的部分,因为它们将在推理期间定义模型的良好性能。

  • 22
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值