Swish 激活函数
Swish 激活函数在 PyTorch 中对应的是 torch.nn.functional.silu
函数。silu
表示 Sigmoid-Weighted Linear Unit,它是 Swish 的别名。
特点:
- 平滑性: Swish 是一个平滑的函数,相较于 ReLU,Swish 的曲线没有不连续点。
- 非单调性: Swish 在负值附近有一个小的负斜率区域,有助于神经网络捕获更复杂的模式。
- 可微性: Swish 是可微的,这对于梯度下降优化算法非常友好。
PyTorch 中的实现:
PyTorch 直接提供了 torch.nn.SiLU
(Sigmoid Linear Unit),这与 Swish 等价。
import torch
import torch.nn.functional as F
# 方法 1: 使用 PyTorch 提供的 SiLU
swish = torch.nn.SiLU()
# 方法 2: 手动实现 Swish
def swish_manual(x):
return x * torch.sigmoid(x)
# 示例
x = torch.tensor([-1.0, 0.0, 1.0])
print(swish(x)) # PyTorch 内置实现
print(swish_manual(x)) # 手动实现
SwiGLU 激活函数
SwiGLU 是一种基于门控线性单元(GLU, Gated Linear Unit)的激活函数。它由以下公式定义:
其中:
- W1,W2 是两个线性变换。
- Swish 是 Swish 激活函数。
特点:
- 门控机制: 使用了乘法操作,将特征选择与激活函数结合起来。
- 高效性: 在 Transformer 等模型中提高了性能,被广泛用于高效的前馈神经网络(如 GPT-3 和 BERT 的改进版本中)。
PyTorch 中的实现:
可以通过自定义实现 SwiGLU:
import torch
import torch.nn as nn
class SwiGLU(nn.Module):
def __init__(self, input_dim, hidden_dim):
super(SwiGLU, self).__init__()
self.fc1 = nn.Linear(input_dim, hidden_dim)
self.fc2 = nn.Linear(input_dim, hidden_dim)
self.swish = nn.SiLU() # 使用内置 Swish 函数
def forward(self, x):
return self.fc1(x) * self.swish(self.fc2(x))
# 示例
input_dim = 128
hidden_dim = 256
x = torch.randn(32, input_dim) # Batch size 32, Input dimension 128
swiglu = SwiGLU(input_dim, hidden_dim)
output = swiglu(x)
print(output.shape) # 输出维度: [32, 256]
总结
- Swish 是一种简单的平滑激活函数,PyTorch 内置了等效的实现 (
torch.nn.SiLU或torch.nn.functional.silu
)。 - SwiGLU 是基于 Swish 的门控激活函数,常用于高效前馈神经网络,需要自定义实现。
这两种激活函数在现代深度学习模型(特别是自然语言处理任务)中表现出色。