一文搞懂 Transformer 里的前馈神经网络:从代码实现到对大模型的深入理解

引言

在之前的内容中,我们深入探讨了 Transformer 的核心 ——Attention 机制,一文带你彻底搞懂!Transformer !!_show vlan int-CSDN博客 无论是缩放、批量、多头、掩码还是交叉注意力机制,都对模型处理信息的能力起到了关键作用。然而,在 Transformer 模型的整体架构里,FFN(Feed - Forward Network,前馈神经网络)层同样占据着举足轻重的地位,它与 Attention 机制相辅相成,共同支撑起模型的强大功能。

本文将全方位、细致地解析 FFN 层。从它的结构组成开始,一步步深入到数学原理的推导,再到源码层面的理解,同时结合其在大模型中的实际应用,详细阐述 FFN 层是如何通过升维和降维操作,有效增强模型的表达能力,以及它在 Transformer 中发挥的独特作用。希望通过本文,能让大家透彻了解 Transformer 的 FFN 层。文章将按照以下逻辑展开:先介绍 FFN 的架构,接着推导其数学表示,然后深入理解源码,再分析 FFN 层的主要作用,从 KV 键值角度进一步解读,探讨 FFN 与 Attention 的非线性特点,解释 FFN 升维到 4d 的原因,最后介绍大模型 FFN 层常用的激活函数 SwiGLU。后续还会持续分享更多大模型、AIGC、Agent、RAG 等学术前沿内容!


FFN 架构

FFN 是 Transformer 中不可或缺的重要组件。我们来看 Transformer 的整体架构,它主要由 Encoder(编码器)和 Decoder(解码器)两大部分构成。在这个架构里,前馈神经网络,也就是 FFN,又被叫做全连接层或密集层,我们用红色标记出来的部分就是它。在 Transformer 模型运行时,FFN 层通常紧跟在编码器和解码器的注意力层后面,在注意力层对输入数据完成关键的特征提取和关联分析后,FFN 层接力对数据进行进一步处理 。

下图我将 FFN 层单独展开来看,它主要包含三个核心部分。首先是「升维线性变换」阶段,输入数据进入 FFN 层后,会先经历一个线性变换过程。这个过程就像是给数据换个 “大一点的容器”,通过一个特定的权重矩阵和偏置向量(不过在一些大模型架构里,也会把偏置向量去掉),将输入数据映射到更高维度的空间。这样做的目的是为了让模型能够在更丰富的维度上捕捉数据的特征和规律。

接着是「非线性激活函数」部分。数据经过升维线性变换后,并不会直接进入下一个环节,而是要通过一个激活函数。激活函数的作用是打破数据的线性关系,为模型引入非线性因素,大大增强模型对复杂数据的表达和处理能力。常见的激活函数有 ReLU、Sigmoid、Tanh 等,在当下主流的大模型中,SwiGLU 是更常用的激活函数。

最后是「降维线性变换」。数据经过激活函数处理后,还要再进行一次线性变换。这次会利用一个降维矩阵,把数据从高维空间重新映射回原始维度,得到 FFN 层最终的输出结果。这一系列操作看似复杂,但每一步都有其特定的功能,共同为 Transformer 模型的强大性能贡献力量。


FFN 数学表示

为了更清晰地理解 FFN 层的工作原理,我们从数学计算的角度来看看它是如何运行的。假设输入数据为 

\mathbf{x} \in \mathbb{R}^{b \times d}

这里的 b 代表数据的批量大小,也就是一次处理多少个数据样本;d 表示隐藏维度,即数据在模型内部的特征维度数量。当我们采用 ReLU 激活函数时,FFN 层具体的计算过程是这样的:

首先,输入 \mathbf{x} 会与第一层的权重矩阵 \mathbf{W}_1 \in \mathbb{R}^{d \times 4d} 进行矩阵乘法运算,再加上第一层的偏置项 \mathbf{b}_1 \in \mathbb{R}^{4d},这个过程就是前面提到的升维线性变换,通常会把输入数据的维度扩展 4 倍。用公式表示为:

\mathbf{z} = \mathbf{x} \mathbf{W}_1 + \mathbf{b}_1

其中,\mathbf{z} \in \mathbb{R}^{b \times 4d}

然后,将得到的结果 \mathbf{z} 传入 ReLU 非线性激活函数中,激活函数会对数据进行处理,引入非线性因素。ReLU 函数的定义为:

\text{ReLU}(x) = \max(0, x)

对 \mathbf{z} 中的每个元素应用 ReLU 函数,得到:

 \mathbf{h} = \text{ReLU}(\mathbf{z})

其中,\mathbf{h} \in \mathbb{R}^{b \times 4d}

最后,经过激活函数处理的数据 \mathbf{h},会与第二层的权重矩阵 \mathbf{W}_2 \in \mathbb{R}^{4d \times d} 相乘,再加上第二层的偏置项 \mathbf{b}_2 \in \mathbb{R}^{d},这一步实现了降维线性变换,将数据维度变回原来的 d 大小,得到最终的输出 \mathbf{y}。用公式表示为:

\mathbf{y} = \mathbf{h} \mathbf{W}_2 + \mathbf{b}_2

其中,\mathbf{y} \in \mathbb{R}^{b \times d}

不过,在目前很多大模型的实际应用中,为了简化计算或者基于特定的优化考虑,会把偏置项去掉。当去掉偏置项,仍然采用 ReLU 激活函数时,FFN 的计算过程就变成了: 首先,输入 \mathbf{x} 直接与第一层的权重矩阵 \mathbf{W}_1 进行矩阵乘法运算:

\mathbf{z} = \mathbf{x} \mathbf{W}_1

然后经过 ReLU 激活函数处理:

\mathbf{h} = \text{ReLU}(\mathbf{z})

最后与第二层的权重矩阵 \mathbf{W}_2 相乘,得到最终的输出 \mathbf{y}

\mathbf{y} = \mathbf{h} \mathbf{W}_2

通过这些数学公式,我们能更准确地把握 FFN 层对数据的处理逻辑和变换过程。 


FFN 源码理解

接下来,我们通过 PyTorch 代码来深入理解 Transformer 编码器层中 FFN 的具体实现。这段代码构建了 Transformer 编码器层的关键组件,其中 FFN 部分的核心逻辑集中在两个全连接层上,它们分别对应前面数学表示中提到的矩阵 \mathbf{W}_1 和 \mathbf{W}_2

首先,我们看到代码中定义了 MultiHeadAttention 类,这是多头自注意力机制的实现。多头自注意力机制在 Transformer 中负责捕捉输入数据不同部分之间的关联关系。在这个类的 __init__ 方法里,初始化了用于将输入映射到查询(Query)、键(Key)和值(Value)的线性变换矩阵 self.wqself.wkself.wv ,以及用于输出的线性变换矩阵 self.fc_out 。在 forward 方法中,通过一系列操作计算出注意力得分,经过 Softmax 函数转换为概率分布,最后与值向量相乘得到输出。

import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super(MultiHeadAttention, self).__init__()
        self.n_heads = n_heads  # 多头注意力的头数
        self.d_model = d_model  # 输入维度(模型的总维度)
        self.head_dim = d_model // n_heads  # 每个注意力头的维度
        assert self.head_dim * n_heads == d_model, "d_model必须能够被n_heads整除"# 断言,确保d_model可以被n_heads整除

        # 线性变换矩阵,用于将输入向量映射到查询、键和值空间
        self.wq = nn.Linear(d_model, d_model)  # 查询(Query)的线性变换
        self.wk = nn.Linear(d_model, d_model)  # 键(Key)的线性变换
        self.wv = nn.Linear(d_model, d_model)  # 值(Value)的线性变换

        # 最终输出的线性变换,将多头注意力结果合并回原始维度
        self.fc_out = nn.Linear(d_model, d_model)  # 输出的线性变换

    def forward(self, query, key, value, mask):
        # 将嵌入向量分成不同的头
        query = query.view(query.shape[0], -1, self.n_heads, self.head_dim)
        key = key.view(key.shape[0], -1, self.n_heads, self.head_dim)
        value = value.view(value.shape[0], -1, self.n_heads, self.head_dim)

        # 转置以获得维度 batch_size, self.n_heads, seq_len, self.head_dim
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        # 计算注意力得分
        scores = torch.matmul(query, key.transpose(-2, -1)) / self.head_dim
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        attention = torch.nn.functional.softmax(scores, dim=-1)

        out = torch.matmul(attention, value)

        # 重塑以恢复原始输入形状
        out = out.transpose(1, 2).contiguous().view(query.shape[0], -1, self.d_model)

        out = self.fc_out(out)
        return out

然后是TransformerEncoderLayer 类,它定义了 Transformer 的编码器层。在这个类的 __init__ 方法中,初始化了多头自注意力层 self.self_attn ,以及 FFN 层的两个关键组件: self.linear1 和 self.linear2 self.linear1 负责将输入从 d_model 维度映射到 dim_feedforward 维度,实现升维操作;self.linear2 则将 dim_feedforward 维度映射回 d_model 维度,完成降维。此外,还定义了 self.dropout 用于防止过拟合,以及两个层归一化层 self.norm1 和 self.norm2 ,用于稳定网络训练。 

class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, dim_feedforward, dropout):
        super(TransformerEncoderLayer, self).__init__()
        
        # 多头自注意力层,接收d_model维度输入,使用n_heads个注意力头
        self.self_attn = MultiHeadAttention(d_model, n_heads)
        
        # 第一个全连接层,将d_model维度映射到dim_feedforward维度
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        
        # 第二个全连接层,将dim_feedforward维度映射回d_model维度
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        
        # 用于随机丢弃部分神经元,以减少过拟合
        self.dropout = nn.Dropout(dropout)
        
        # 第一个层归一化层,用于归一化第一个全连接层的输出
        self.norm1 = nn.LayerNorm(d_model)
        
        # 第二个层归一化层,用于归一化第二个全连接层的输出
        self.norm2 = nn.LayerNorm(d_model)

在 TransformerEncoderLayer 类的 forward 方法中,详细展示了数据在编码器层中的流动过程。首先,输入 src 经过多头自注意力层 self.self_attn 处理,得到 src2 。然后通过残差连接和层归一化,将自注意力层的输出与原始输入结合并归一化。接着,数据进入 FFN 层:先通过 self.linear1 进行升维,再经过 torch.nn.functional.relu 激活函数引入非线性,之后由 self.linear2 完成降维,中间还使用 self.dropout 进行随机丢弃操作。最后,再次通过残差连接和层归一化,输出最终结果。 

    def forward(self, src, src_mask):
        # 使用多头自注意力层处理输入src,同时提供src_mask以屏蔽不需要考虑的位置
        src2 = self.self_attn(src, src, src, src_mask)
        
        # 残差连接和丢弃:将自注意力层的输出与原始输入相加,并应用丢弃
        src = src + self.dropout(src2)
        
        # 应用第一个层归一化
        src = self.norm1(src)

        # 经过第一个全连接层,再经过激活函数ReLU,然后进行丢弃
        src2 = self.linear2(self.dropout(torch.nn.functional.relu(self.linear1(src))))
        
        # 残差连接和丢弃:将全连接层的输出与之前的输出相加,并再次应用丢弃
        src = src + self.dropout(src2)
        
        # 应用第二个层归一化
        src = self.norm2(src)

        # 返回编码器层的输出
        return src

下述代码是通过model = Transformer(vocab_size, d_model, n_heads, num_encoder_layers, dim_feedforward, max_seq_length, dropout)完成模型的实例化,此时model就是一个完整的 Transformer 模型对象,可以用于后续的训练和推理 ,例如接收合适的输入数据张量和对应的掩码张量,调用model对象并传入数据,就能得到模型的输出结果。 

# 定义Transformer模型
class Transformer(nn.Module):
    def __init__(self, vocab_size, d_model, n_heads, num_encoder_layers, dim_feedforward, max_seq_length, dropout):
        super(Transformer, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)  # 嵌入层,将单词索引转换为向量
        self.position_encoding = nn.Parameter(torch.zeros(1, max_seq_length, d_model))  # 位置编码
        encoder_layer = TransformerEncoderLayer(d_model, n_heads, dim_feedforward, dropout)
        self.encoder = nn.Sequential(*[encoder_layer for _ in range(num_encoder_layers)])  # 堆叠编码器层

    def forward(self, x, mask):
        x = self.embedding(x)  # 嵌入操作
        x = x + self.position_encoding[:, :x.size(1), :]  # 添加位置编码
        x = self.encoder(x, mask)  # 通过编码器层
        return x


vocab_size = 10000  # 词汇表大小(根据实际情况调整)
d_model = 512  # 模型的维度
n_heads = 8  # 多头自注意力的头数
num_encoder_layers = 6  # 编码器层的数量
dim_feedforward = 2048  # 全连接层的隐藏层维度
max_seq_length = 100  # 最大序列长度
dropout = 0.1  # 丢弃率

# 创建Transformer模型实例
model = Transformer(vocab_size, d_model, n_heads, num_encoder_layers, dim_feedforward, max_seq_length, dropout)
print(model)

  


FFN 层作用

从 FFN 具体各层的角度来看:

升维层当输入数据进入升维层时,它会通过一个线性变换操作,将数据映射到更高的维度空间。在低维空间里,数据特征之间的关系可能比较模糊,难以充分表达数据的复杂结构。而升高向量维度,就好比给数据提供了更广阔的 “舞台”,让模型能够挖掘出更复杂的特征关系。因为在高维空间中,向量之间有更多的 “自由度”,原本在低维难以区分的数据,在高维下可能就会变得容易区分,模型可以更好地捕捉到数据之间细微的差异和联系。

下面是一个特别直观的例子,在2维空间中红蓝两色的点不好区分,但将其映射到3维空间,就能够比较容易的进行区分。尽管当前模型动不动就是几百上千维,原理是是一致的。

激活层激活层在 FFN 中扮演着至关重要的角色。在神经网络里,每个神经元的输出通常是输入的加权和,如果没有激活函数,无论神经网络有多少层,最终输出都只是输入的线性组合。这样的线性模型能表达的关系非常有限,无法学习现实世界中复杂的非线性关系。而激活函数的引入,打破了这种线性关系,让模型能够学习和拟合各种复杂的函数曲线,从而大大增强了模型的表达能力。

降维层虽然升维操作可以帮助模型捕捉更多信息,但过高的维度会带来计算量的大幅增加,还可能导致模型过拟合。降维层的作用就是去除冗余信息,将高维特征浓缩成更精炼的表示。它通过特定的线性变换,把数据从高维空间映射回原始或较低维度空间,既控制了模型的复杂度和计算成本,又确保了 FFN 的输出维度与输入维度一致,方便后续神经网络层的处理和连接。

从整体 FFN 层的结构来看:

1、维度扩展和特征抽取FFN 的第一层全连接层 self.linear1 将输入维度从 d_model 扩展到 dim_feedforward ,这相当于增加了模型的容量。在更高的维度上,模型有更多的 “空间” 去处理和提取信息。而第二层全连接层 self.linear2 又将维度降回 d_model ,这样既不会使模型参数过多导致训练困难,又实现了对信息的有效压缩和提取,保证模型在学习到足够特征的同时保持高效。

2、引入非线性变换Transformer 中的注意力机制本质上是线性的,它主要通过计算不同 token(可以理解为输入数据的基本单元)之间的加权和来捕捉关系。这种线性操作虽然能处理一些简单的关系,但对于复杂的特征和关系表达能力有限。而 FFN 通过激活函数引入非线性变换,弥补了自注意力机制的这一局限性,让模型能够学习和表达更复杂的特征和关系

3、位置独立处理在 Transformer 架构中,位置编码一般会在注意力阶段进行,不同的模型采用不同的方式添加位置编码,比如 Transformer 用正余弦函数添加,Bert 用可学习的方式添加,现在很多生成式大模型用 RoPE 方法添加。而 FFN 层在处理数据时,不会引入额外的位置信息,它对每个位置的特征向量进行独立的非线性变换。这样一来,FFN 层可以专注于增强每个位置自身的特征,不会受到其他位置信息的干扰。这与自注意力机制的全局交互特性形成了很好的互补,使得模型既能捕捉到每个位置的局部特征,又能处理好全局的依赖关系。

4、下游任务匹配Transformer 模型的设计初衷之一,就是能够灵活应用于各种不同的任务,像自然语言处理、计算机视觉等领域。FFN 层的结构相对简单,但通过调整它的参数,比如隐藏层的维度、使用的激活函数等,就可以很方便地改变模型的表达能力和复杂度。这种灵活性让 Transformer 模型能够根据不同任务的需求进行调整,更好地适应各种实际应用场景。


从键值对 (KV) 理解 FFN

在前面的内容中,我们了解到 FFN(前馈神经网络)在模型中承担着存储训练数据知识的重要角色,这一观点不仅符合我们的直观认知,也与当前主流的学术研究方向一致。例如,在 MoE(混合专家)架构中,专家模型以及通用模型的实现都依赖于 FFN 层;在大模型的 Adapter 微调中,新增的 Adapter 模块同样由 FFN 层构成;即便是在大模型 LoRA 微调时,旁侧添加的 A、B 矩阵,本质上也是 FFN 层的一部分。

为了更深入地理解 FFN 层的工作原理,一篇名为《Transformer Feed - Forward Layers Are Key - Value Memories》的文章提出了一个非常有趣的观点:FFN 所学习到的知识,是以键值对(KV Memory)的形式存储在自身结构中的。其中,“键(Key)” 存储着从训练数据中提取出的特定特征,而对应的 “值(Value)” 则记录了在这些特征下预测下一个词的概率分布。

回顾我们之前对 Attention 机制的介绍,Attention 的基本计算公式为:

\text{Attention}(Q, K, V) = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})V

 其中,Q(Query)为查询向量,K(Key)为键向量,V(Value)为值向量,d_k 是键向量的维度。而 FFN 的计算公式为(假设输入为 \mathbf{x},第一层权重矩阵为 \mathbf{W}_1,第二层权重矩阵为 \mathbf{W}_2,激活函数为 \sigma

\mathbf{y} = \mathbf{W}_2 \sigma(\mathbf{W}_1 \mathbf{x})

 如果我们将 FFN 中的 \mathbf{W}_1 矩阵看作 K,\mathbf{W}_2 矩阵看作 V,把激活函数(如 ReLU)类比为 softmax,结合下图进行对比分析:

那么 FFN 的计算公式可以重新写成类似 Attention 的形式:

这里,输入 \mathbf{x} 与每个 \mathbf{k}_i\mathbf{W}_1 中的行向量,类比为 K 中的元素)点乘得到系数(memory coefficient),然后与对应的 \mathbf{v}_i\mathbf{W}_2 中的行向量,类比为 V 中的元素)加权求和,最终得到 FFN 的输出。从形式上看,这个过程与 QKV 点乘注意力的计算非常相似。

不过,FFN 与 Attention 在计算过程中也存在明显的区别:

  1. 上下文依赖关系不同:Attention 的计算是上下文相关(context - dependent)的,其中的 Q、K、V 都来自于数据的表示,并且会随着输入数据的不同而变化,长度不固定;而 FFN 中的 K 和 V 是上下文无关(context - independent)的,Q 来自数据表示,K 和 V 则分别来自于两个可学习的参数矩阵中的向量,在模型训练完成后,它们的数值是固定的 。
  2. 激活函数特性不同:Attention 中的激活函数是具有归一化性质的 softmax 函数,它会将计算得到的注意力分数转换为概率分布;而 FFN 使用的是未归一化的激活函数,比如 ReLU 函数,它直接对输入进行非线性变换 。

文章作者通过将 FFN 与 Attention 进行对比,深入分析了 FFN 的知识存储机制和内部信息流动过程。实验结果表明,FFN 确实能够将训练数据中的一些模式或知识记忆并存储起来。从这个角度来讲,Attention 主要负责提取短期的上下文信息,而 FFN 则专注于对整个训练样本的信息进行提取和记忆。这也就解释了为什么即便在使用有限窗口甚至对语料进行截断的情况下,模型依然能够记住语料库中的信息。


FFN 与 Attention 的非线性

前面我们提到,FFN 的主要作用之一是引入非线性变换,目的是让模型能够学习和拟合复杂的函数关系,从而提升模型的表达能力。然而,通过上一节的对比,我们发现 Attention 机制也会通过 softmax 函数引入非线性变换。那么,这里就产生了一个问题:既然 Attention 已经引入了非线性,FFN 为什么还要再次引入非线性变换呢?这个问题在面试中经常被用来考察候选人对 Transformer 模型细节的理解程度。

要回答这个问题,我们首先需要明确另一个关键问题:Attention 机制是线性运算还是非线性运算?从全局角度来看,对于输入数据而言,Attention 是一种非线性运算。观察 Attention 的计算公式,其中确实包含针对 QK^T 的 softmax 非线性运算。但是,对于值向量 V 来说,在计算过程中并没有引入额外的非线性变换,每一次 Attention 的计算本质上相当于对 V 代表的向量进行加权平均。

此外,Attention 机制中的非线性变换(如 Softmax)和 FFN 中的非线性变换(如 ReLU)有着本质区别。Softmax 函数主要用于将注意力分数进行归一化处理,使其转换为概率分布,从而确定不同信息的权重;而 ReLU 等激活函数的作用是通过引入非线性,对输入特征进行 “重塑”,让模型能够学习到更加复杂、丰富的特征表示 。具体来说,Softmax 是对多个数值进行归一化,使得它们的和为 1,突出数值之间的相对大小关系;而 ReLU 函数则是直接对每个数值进行操作,当数值大于 0 时保持不变,小于 0 时则置为 0,这种操作打破了输入与输出之间的线性关系,让模型能够捕捉到数据中更复杂的模式。


FFN 为什么升维 4d?

在 FFN 中,第一层的权重矩阵 \mathbf{W}_1 一般是 d \times 4d 的维度(假设输入维度为 d),根据线性代数的知识,矩阵的秩(Rank)最高为 d。这意味着在 \mathbf{W}_1 中,至少有 3d 行可以由其他行通过线性组合表示出来。

假设某个输入 \mathbf{x} 在经过 \mathbf{W}_1 矩阵以及激活函数后的隐藏状态为 \mathbf{h},那么 FFN 的输出 \mathbf{y} 可以表示为对 \mathbf{h} 的 4d 行进行加权求和的过程,用公式表示为:

\mathbf{y} = \sum_{i = 1}^{4d} \mathbf{w}_{2, i}^T \mathbf{h}_i

 其中,\mathbf{w}_{2, i}^T 代表 \mathbf{W}_2 的第 i 行,\mathbf{h}_i 代表 \mathbf{h} 的第 i 个分量。

由于 \mathbf{W}_1 的低秩特性,我们一定可以在 \mathbf{W}_1 中找到 n (n \leq d)个线性无关的行(假设就是前 n 行)。对于任意的 \mathbf{h},都可以由这 n 行通过线性组合表示出来,即:

\mathbf{h} = \sum_{j = 1}^{n} \alpha_j \mathbf{w}_{1, j}^T \mathbf{x}

其中,\mathbf{w}_{1, j}^T 是 \mathbf{W}_1 中线性无关的行向量,\alpha_j 是对应的系数。

那么,上述 FFN 输出的加权求和公式就可以改写为:

\mathbf{y} = \sum_{i = 1}^{4d} \mathbf{w}_{2, i}^T \left( \sum_{j = 1}^{n} \alpha_j \mathbf{w}_{1, j}^T \mathbf{x} \right)

从这个角度来看,似乎我们只需要保留 \mathbf{W}_1 中这 n 个线性无关的行,模型只需要学习如何通过这部分行向量和激活函数得到 “新的” 特征表示,这样一来,FFN 中先升维再降维的操作看起来就显得多余了。那么,为什么在实际应用中,降低 FFN 中矩阵的维度会导致模型效果变差呢?

很多人会回答说升维是为了将数据投影到高维空间,增强信息的分离能力,而降维是为了与下一步计算的维度保持一致。并且,选择将维度提升 4 倍是基于经验,因为这样在计算成本、参数规模和表达能力之间达到了较好的平衡。但实际上,这种回答并没有触及问题的本质。FFN 升维的根本原因与非线性激活函数密切相关,因为非线性激活函数会导致大约 50% 的信息丢失,而在使用激活函数前进行升维,就是为了补偿这部分损失。

我们知道,激活函数的本质是通过 “舍弃” 部分信息,使数据的结构呈现出非线性。从随机信号处理的角度来看,信息丢失 50% 是一个较为理想的数值:

  • 如果信息丢失过少(小于 50%),神经元的非线性表达能力就会不足,模型难以学习到复杂的函数关系;
  • 如果信息丢失过多(大于 50%),则可能导致有效信息不足,使得模型在训练过程中变得不稳定 。

由于 FFN 采用的是两层全连接网络结构:

  1. 第一层(升维):将输入维度从 d 提升到 4d;
  2. 经过激活函数:以 ReLU 函数为例,大约会有一半的神经元输出变为 0,这意味着信息丢失了 1/2;
  3. 第二层(降维):将维度从 4d 降回到 d,在这个过程中,等效于有效信息又减少了近 1/2。

经过这样的过程,最终有效信息流减少到原来的 \frac{1}{4}。如果我们希望最终的有效信息能够保持与原始输入相同的水平,就必须提前进行升维操作来补偿信息损失。因此,将维度提升 4 倍就成为了一个自然的选择。


大模型 FFN 的 SwiGLU 激活函数

在前面讨论 FFN 层时,我们提到了激活函数对于引入非线性、提升模型表达能力的重要性。现在,我们来详细了解一下当前主流大模型在 FFN 层广泛使用的激活函数 ——SwiGLU(Swish-Gated Linear Unit)。SwiGLU 最早由 Google DeepMind 在论文《Scaling Laws for Neural Language Models》中提出。它巧妙地融合了 Swish 激活函数和 Gated Linear Unit(GLU,门控线性单元)机制,相比传统的 ReLU 函数,在大多数评估任务中都展现出显著优势,能够有效增强神经网络的表达能力,同时提升训练效率。两个激活函数如下图:

为了更好地理解 SwiGLU,我们先来分别认识它的两个关键组成部分:Swish 激活函数和 GLU 机制。

Swish 激活函数

Swish 是一种平滑的非单调激活函数,其数学定义为:

\text{Swish}(x) = x \cdot \sigma(x)

其中,\sigma(x) 是 Sigmoid 函数,公式为

 \sigma(x) = \frac{1}{1 + e^{-x}} 。

Sigmoid 函数会将任意实数压缩到 0 到 1 之间,当这个结果再与输入 x 相乘时,就得到了 Swish 函数的输出。

从图像上看,当参数 \beta(在标准 Swish 函数中默认为 1)趋近于 0 时,Swish 函数的曲线会越来越接近线性函数 y = x ;当 \beta 趋近于无穷大时,Swish 函数则趋近于 ReLU 函数 y = \max(0, x);而当 \beta 取值为 1 时,Swish 函数呈现出光滑且非单调的形态。在 HuggingFace 的 Transformer 库中,通常使用 silu 函数来替代 Swish 函数,它们在数学表达上是等价的。

研究发现,Swish 函数相比 ReLU 函数具有明显优势。ReLU 函数在输入小于 0 时输出直接置为 0,这种 “硬截断” 的特性可能导致一些信息丢失,并且在零点处不可导,会影响模型优化。而 Swish 函数在零点附近非常平滑,这使得模型在训练过程中能够更高效地进行梯度计算和参数更新,有助于实现更好的优化效果,加快模型收敛速度。

GLU 机制

GLU,即门控线性单元,它的核心思想是通过引入门控机制来控制信息流。GLU 的计算公式为:

\text{GLU}(x) = (x \cdot W_1 + b_1) \odot (x \cdot W_2 + b_2)

其中,x 是输入张量,W_1W_2 是可训练的权重矩阵,b_1b_2 是对应的偏置向量,\odot 表示逐元素乘法。GLU 通过两组不同的权重和偏置对输入进行变换,然后将这两个变换结果逐元素相乘,以此来动态调节信息的传递。简单理解,就像是给信息流安装了一个 “阀门”,可以根据输入数据的特征,灵活决定有多少信息能够通过。

SwiGLU 的原理与实现

将 Swish 激活函数和 GLU 机制结合,就得到了 SwiGLU 激活函数。它的数学表达式为:  

\text{SwiGLU}(x) = \text{Swish}(x \cdot W_1 + b_1) \cdot (x \cdot W_2 + b_2)

在实际应用中,为了保证前馈网络(FFN)的参数总量不变,通常会将 FFN 隐藏层的维度设置为原来的两倍。这是因为 SwiGLU 引入了额外的线性变换(对应公式中的 x \cdot W_2 + b_2 部分),为了平衡参数数量,就需要扩大隐藏层维度来补偿,确保模型的复杂度和计算量处于合理范围。

下面是使用 PyTorch 实现 SwiGLU 激活函数的代码:

import torch
import torch.nn as nn

class SwiGLU(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # 定义第一个线性变换,对应公式中的 x * W_1 + b_1
        self.w1 = nn.Linear(input_dim, hidden_dim)
        # 定义第二个线性变换,对应公式中的 x * W_2 + b_2
        self.w2 = nn.Linear(input_dim, hidden_dim)
    
    def forward(self, x):
        # 计算 Swish 部分:torch.sigmoid(self.w1(x))
        # 与第二个线性变换结果逐元素相乘
        return torch.sigmoid(self.w1(x)) * self.w2(x)

# 示例
x = torch.randn(4, 10)  # 生成4个样本,每个样本10维的随机张量
swiglu = SwiGLU(10, 20)  # 初始化SwiGLU,输入维度10,隐藏维度20
output = swiglu(x)  # 将输入张量传入SwiGLU
print(output.shape)  # 输出张量形状,应该是 torch.Size([4, 20])

 

在这段代码中,SwiGLU 类的 __init__ 方法初始化了两个线性层 self.w1 和 self.w2 ,分别对应 SwiGLU 公式中的两组线性变换。forward 方法则按照公式实现了具体计算过程:先对 self.w1 的输出应用 Sigmoid 函数得到 Swish 部分,再与 self.w2 的输出进行逐元素相乘,得到最终结果。

SwiGLU 与其它激活函数的对比

为了更直观地了解 SwiGLU 的优势,我们将它与其他常见激活函数(如 ReLU、GELU 等)进行对比:

激活函数公式特点优势不足
ReLUy = \max(0, x)简单高效,计算速度快;在正半轴线性,负半轴输出 0缓解梯度消失问题;广泛应用,易于实现负半轴输出恒为 0,导致信息丢失;零点不可导,影响优化
GELUy = x \cdot \Phi(x) ,\Phi(x) 为高斯累积分布函数的近似平滑连续,根据输入分布自适应调节表达能力强,能更好捕捉数据特征计算复杂度相对较高
SwiGLU\text{SwiGLU}(x) = \text{Swish}(x \cdot W_1 + b_1) \cdot (x \cdot W_2 + b_2)融合 Swish 和 GLU,动态控制信息流增强模型表达能力;训练效率高;在长序列任务中表现出色增加了参数和计算量(通过调整隐藏层维度平衡)

通过对比可以看出,SwiGLU 激活函数凭借独特的设计,在提升模型性能和训练效率上具有显著优势,这也是它在当前主流大模型的 FFN 层中被广泛采用的重要原因。


评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值