Implementing SwiGLU and GELU by Hand

GELU (Gaussian Error Linear Unit):
  • Formula
    $$\text{GELU}(x) = x \cdot \Phi(x) = x \cdot \frac{1}{2}\left(1 + \text{erf}\left(\frac{x}{\sqrt{2}}\right)\right)$$
    where Φ(x) is the cumulative distribution function of the standard normal distribution.

  • Approximation (the version most often used in practice):
    $$\text{GELU}(x) \approx 0.5 \cdot x \cdot \left(1 + \tanh\left(\sqrt{\frac{2}{\pi}} \cdot \left(x + 0.044715 \cdot x^3\right)\right)\right)$$
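To see how close the tanh approximation is to the exact erf form, here is a minimal framework-free sketch using only Python's `math` module. The function names `gelu_exact` and `gelu_tanh` are illustrative, not from any library.

```python
import math

def gelu_exact(x: float) -> float:
    # x * Phi(x), with the standard normal CDF Phi written via erf
    return x * 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_tanh(x: float) -> float:
    # tanh approximation with the 0.044715 cubic correction term
    c = math.sqrt(2.0 / math.pi)
    return 0.5 * x * (1.0 + math.tanh(c * (x + 0.044715 * x ** 3)))

# Scan [-6, 6] in steps of 0.01 and record the largest absolute gap
xs = [i / 100.0 for i in range(-600, 601)]
max_gap = max(abs(gelu_exact(x) - gelu_tanh(x)) for x in xs)

print(f"GELU(1.0) exact  = {gelu_exact(1.0):.6f}")
print(f"GELU(1.0) approx = {gelu_tanh(1.0):.6f}")
print(f"max |exact - approx| on [-6, 6] = {max_gap:.2e}")
```

The printed gap is small relative to the activation values themselves, which is why the tanh form is commonly treated as a drop-in replacement for the erf form.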

SwiGLU (Swish-Gated Linear Unit):
  • Formula
    $$\text{SwiGLU}(x) = \text{Swish}(xW + b) \otimes (xV + c)$$
    where xW + b and xV + c are two separate linear projections of the same input x, and ⊗ is element-wise multiplication, so the Swish-activated branch gates the un-activated branch. Swish is a smooth activation function:
    $$\text{Swish}(x) = x \cdot \sigma(x) = \frac{x}{1 + e^{-x}}$$
    where σ is the sigmoid function. More generally Swish_β(x) = x · σ(βx); with β = 1 it is identical to SiLU, which PyTorch exposes as torch.nn.functional.silu.
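A framework-free sketch of SwiGLU in its standard form Swish(xW) ⊗ xV (biases omitted), assuming hand-picked 2×2 weights `W` and `V` whose values are hypothetical and chosen only so the arithmetic can be followed by hand. The helper names `swish`, `matvec`, and `swiglu` are illustrative.

```python
import math

def swish(x: float, beta: float = 1.0) -> float:
    # Swish_beta(x) = x * sigmoid(beta * x); beta = 1 gives SiLU
    return x / (1.0 + math.exp(-beta * x))

def matvec(x, M):
    # Row vector x (length n) times matrix M (n rows x m cols) -> length m
    return [sum(x[i] * M[i][j] for i in range(len(x))) for j in range(len(M[0]))]

def swiglu(x, W, V):
    # Gate branch goes through Swish, value branch stays linear,
    # then the two are multiplied element-wise
    gate = [swish(g) for g in matvec(x, W)]
    value = matvec(x, V)
    return [g * v for g, v in zip(gate, value)]

# Hypothetical weights, for illustration only
W = [[1.0, 0.0],
     [0.0, 1.0]]
V = [[2.0, 0.0],
     [0.0, -1.0]]
x = [1.0, -2.0]
print(swiglu(x, W, V))
```

Here xW = [1, -2] and xV = [2, 2], so the output is [2·Swish(1), 2·Swish(-2)]: the negative gate input is squashed toward zero while the positive one passes through mostly intact.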
GELU implementation (PyTorch built-in):
import torch
import torch.nn as nn

# GELU activation (PyTorch built-in; uses the exact erf form by default)
gelu = nn.GELU()

# Input tensor
x = torch.randn(2, 5)
output = gelu(x)
print(output)
GELU implementation (hand-written tanh approximation):
import math
import torch
import torch.nn as nn

class GELUApprox(nn.Module):
    def forward(self, x):
        # tanh approximation of GELU; sqrt(2/pi) is computed once as a plain Python float
        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

# Example
x = torch.randn(2, 5)
gelu_approx = GELUApprox()
output = gelu_approx(x)
print(output)
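As a sanity check, the hand-written formulas can be compared against PyTorch's `torch.nn.functional.gelu`, whose `approximate` argument selects the exact erf form (`'none'`, the default) or the tanh form (`'tanh'`). This sketch assumes a PyTorch release that supports the `approximate` argument (roughly 1.12 or later).

```python
import math
import torch
import torch.nn.functional as F

x = torch.linspace(-6.0, 6.0, steps=1201)

# Exact form: x * Phi(x), with Phi written via erf
manual_exact = x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
# tanh form, the same expression as GELUApprox above
manual_tanh = 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

builtin_exact = F.gelu(x)                     # approximate='none' by default
builtin_tanh = F.gelu(x, approximate="tanh")  # built-in tanh approximation

print((manual_exact - builtin_exact).abs().max().item())
print((manual_tanh - builtin_tanh).abs().max().item())
```

Both printed differences should be at the level of float32 rounding, confirming that each hand-written version matches the corresponding built-in mode.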

SwiGLU implementation:
import torch
import torch.nn as nn

class SwiGLU(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        # Two linear projections of the same input
        self.linear1 = nn.Linear(d_model, d_model)  # gate branch (goes through Swish)
        self.linear2 = nn.Linear(d_model, d_model)  # value branch (stays linear)

    def forward(self, x):
        # Swish(xW + b) * (xV + c); SiLU is PyTorch's name for Swish with beta = 1
        return torch.nn.functional.silu(self.linear1(x)) * self.linear2(x)

# Input tensor
x = torch.randn(2, 5)

# SwiGLU layer
swiglu = SwiGLU(d_model=5)
output = swiglu(x)
print(output)
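In Transformer feed-forward layers, SwiGLU is usually placed between an up-projection and a down-projection whose hidden width differs from `d_model`. A minimal sketch of that layout, assuming a hypothetical `d_hidden` parameter and a single fused `nn.Linear` that produces both branches at once (split with `chunk`); the class name `SwiGLUFFN` is illustrative, not a library API.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFFN(nn.Module):
    # Hypothetical feed-forward block: d_model -> (gate, value) -> d_model
    def __init__(self, d_model: int, d_hidden: int, bias: bool = False):
        super().__init__()
        # One fused projection produces the gate and value branches together
        self.w_in = nn.Linear(d_model, 2 * d_hidden, bias=bias)
        # Project the gated hidden state back to d_model
        self.w_out = nn.Linear(d_hidden, d_model, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # First half of the output features is the gate, second half the value
        gate, value = self.w_in(x).chunk(2, dim=-1)
        return self.w_out(F.silu(gate) * value)

torch.manual_seed(0)
ffn = SwiGLUFFN(d_model=8, d_hidden=16)
x = torch.randn(2, 3, 8)  # (batch, seq_len, d_model)
y = ffn(x)
print(y.shape)  # torch.Size([2, 3, 8])
```

The fused projection computes the same thing as two separate `nn.Linear` layers whose weight matrices are stacked row-wise; it simply issues one matrix multiply instead of two.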