AI大模型 | 深入理解Transformer核心原理

Transformer模型自2017年由Google的研究人员提出以来,已成为自然语言处理领域的革命性进展。其独特的自注意力机制使其在多种任务上表现卓越,包括机器翻译、文本生成和语义理解等。本篇博客旨在深入解析Transformer的核心原理和实现方式。

Transformer模型架构

Transformer完全基于注意力机制,去除了传统的递归和卷积层,主要由两大部分组成:编码器(Encoder)和解码器(Decoder)。

  1. 编码器: 编码器由N个相同的层堆叠而成,每层有两个子层。第一个子层是多头注意力机制,第二个子层是简单的全连接前馈网络。每个子层周围都有一个残差连接,后接一个层归一化。

  2. 解码器: 解码器也由N个相同的层组成,但在两个与编码器相同的子层外,还插入了一个第三子层,用于进行多头注意力机制的编码器-解码器交互。

自注意力机制

自注意力机制是Transformer的核心,它允许模型在序列的不同位置间直接交互和学习。每个注意力函数可以被描述为对一个查询(Query)和一组键值对(Keys and Values)的映射。具体计算公式如下:
在这里插入图片描述

多头注意力

多头注意力机制通过将注意力模型的查询、键和值矩阵拆分成多个头,可以让模型同时从不同的表示子空间学习信息:

在这里插入图片描述

前馈网络

每个编码器和解码器内的注意力层后面都跟随一个前馈网络,这个网络对每个位置都是分开同等处理的。这意味着相同位置的单词会通过相同的函数进行处理。通常情况下,前馈网络由两个线性变换和一个ReLU激活函数构成:

在这里插入图片描述

训练与优化

Transformer通常使用位置编码来使模型能够利用序列的顺序信息,采用Adam优化器以及基于注意力权重的学习率调整策略。这些设计让Transformer在处理长序列数据时表现出色。

实现示例

下面的代码是一个简化版的Transformer模型实现,使用PyTorch框架。这个实现包括了自注意力机制、多头注意力、位置编码、以及完整的编码器和解码器结构。

导入必要的库

import torch
import torch.nn as nn
import torch.optim as optim
import math

定义多头注意力模块

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(MultiHeadAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert self.head_dim * heads == embed_size, "Embedding size needs to be divisible by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split the embedding into self.heads different pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # Einsum does matrix multiplied by matrix product for query and keys and scale it
        attention = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        attention = attention / (self.embed_size ** (1 / 2))
        if mask is not None:
            attention = attention.masked_fill(mask == 0, float('-1e20'))

        attention = torch.softmax(attention, dim=-1)

        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )
        out = self.fc_out(out)
        return out

定义位置编码

class PositionalEncoding(nn.Module):
    def __init__(self, embed_size, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.encoding = torch.zeros(max_len, embed_size)
        pos = torch.arange(0, max_len).unsqueeze(1).float()
        _2i = torch.arange(0, embed_size, step=2).float()

        self.encoding[:, 0::2] = torch.sin(pos / (10000 ** (_2i / embed_size)))
        self.encoding[:, 1::2] = torch.cos(pos / (10000 ** (_2i / embed_size)))
        self.encoding = self.encoding.unsqueeze(0)

    def forward(self, x):
        return x + self.encoding[:, :x.size(1)].detach()

编码器和解码器层

class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super(TransformerBlock, self).__init__()
        self.attention = MultiHeadAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        attention = self.attention(value, key, query, mask)
        x = self.dropout(self.norm1(attention + query))
        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward + x))
        return out

构建整体Transformer

这部分代码初始化编码器和解码器层,以及其它必要的组件来完整构建一个Transformer模型。

class Transformer(nn.Module):
    def __init__(self,
                 src_vocab_size,
                 trg_vocab_size,
                 src_pad_idx,
                 trg_pad_idx,
                 embed_size=256,
                 num_layers=6,
                 forward

_expansion=4,
                 heads=8,
                 dropout=0.1,
                 max_length=100):
        super(Transformer, self).__init__()
        self.encoder = Encoder(src_vocab_size,
                               embed_size,
                               num_layers,
                               heads,
                               device,
                               forward_expansion,
                               dropout,
                               max_length)

        self.decoder = Decoder(trg_vocab_size,
                               embed_size,
                               num_layers,
                               heads,
                               forward_expansion,
                               dropout,
                               device,
                               max_length)

        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device

    def make_src_mask(self, src):
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
        return src_mask.to(self.device)

    def make_trg_mask(self, trg):
        N, trg_len = trg.shape
        trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(
            N, 1, trg_len, trg_len
        )
        return trg_mask.to(self.device)

    def forward(self, src, trg):
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        enc_src = self.encoder(src, src_mask)
        out = self.decoder(trg, enc_src, src_mask, trg_mask)
        return out

这个实现包括了构建整个模型所需的主要部件。您可以根据需要进一步添加具体的编码器和解码器实现,以及完整的模型训练和评估逻辑。

结论

Transformer模型以其强大的处理能力和灵活性,在多个领域都表现出了优越的性能,尤其是在自然语言处理领域。通过理解其内部工作原理,研发人员可以更好地利用这一模型解决各类复杂的序列数据任务。

如有侵权,请联系删除。


最后分享

  • 23
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值