常见激活函数及其应用
在深度学习中,激活函数是神经网络中的重要组成部分。它们引入非线性,使模型能够学习复杂的模式。本文将介绍几种常见的激活函数,并提供它们的图示、优缺点和适用场景。
激活函数图示
下面的图展示了几种常见激活函数的形状:
激活函数概述
激活函数 | 公式 | 优点 | 缺点 | 适用场景 |
---|---|---|---|---|
Sigmoid | f ( x ) = 1 1 + e − x f(x) = \frac{1}{1 + e^{-x}} f(x)=1+e−x1 | 输出可以解释为概率,适用于二分类问题。 | 在极端值时梯度消失,导致学习缓慢。 | 常用于输出层,尤其是二分类问题。 |
Tanh | f ( x ) = tanh ( x ) = e x − e − x e x + e − x f(x) = \tanh(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}} f(x)=tanh(x)=ex+e−xex−e−x | 输出均值为 0,通常能更快收敛。 | 仍然存在梯度消失问题。 | 常用于隐藏层,适合需要输出为负值和正值的场景。 |
ReLU | f ( x ) = max ( 0 , x ) f(x) = \max(0, x) f(x)=max(0,x) | 计算简单,收敛速度快,能够有效缓解梯度消失问题。 | 在负值区域的梯度为 0,可能导致“死亡神经元”现象。 | 广泛用于隐藏层,尤其是深度神经网络。 |
Leaky ReLU | f ( x ) = { x if x > 0 α x if x ≤ 0 f(x) = \begin{cases} x & \text{if } x > 0 \\ \alpha x & \text{if } x \leq 0 \end{cases} f(x)={xαxif x>0if x≤0 | 解决了 ReLU 的“死亡神经元”问题。 | 仍然可能出现负值区域的梯度消失。 | 适合需要避免“死亡神经元”的场景,常用于隐藏层。 |
Softmax | f ( x i ) = e x i ∑ j e x j f(x_i) = \frac{e^{x_i}}{\sum_{j} e^{x_j}} f(xi)=∑jexjexi | 输出为概率分布,适合多分类问题。 | 对异常值敏感,可能导致数值不稳定。 | 常用于输出层,特别是多分类问题。 |
Swish | f ( x ) = x ⋅ sigmoid ( x ) f(x) = x \cdot \text{sigmoid}(x) f(x)=x⋅sigmoid(x) | 在某些任务中,Swish 函数表现优于 ReLU。 | 计算复杂度高于 ReLU。 | 适合深度学习模型,尤其是在需要更复杂非线性的场景。 |
GELU | f ( x ) = x ⋅ Φ ( x ) = x ⋅ 1 2 ( 1 + tanh ( 2 / π ( x + 0.044715 x 3 ) 1 ) ) f(x) = x \cdot \Phi(x) = x \cdot \frac{1}{2} \left(1 + \text{tanh}\left(\frac{\sqrt{2/\pi}(x + 0.044715x^3)}{1}\right)\right) f(x)=x⋅Φ(x)=x⋅21(1+tanh(12/π(x+0.044715x3))) | 在一些深度学习模型(如 BERT 和 GPT)中表现良好。 | 实现相对复杂,计算开销较大。 | 适合高级深度学习模型,尤其是自然语言处理任务。 |
总结
选择合适的激活函数对于神经网络的性能和收敛速度至关重要。一般来说,ReLU 和其变体(如 Leaky ReLU 和 PReLU)在隐藏层中非常流行,而 Sigmoid 和 Softmax 通常用于输出层。根据具体任务和数据的特性,选择合适的激活函数可以帮助提高模型的表现。
import numpy as np
import matplotlib.pyplot as plt
# 定义激活函数
def sigmoid(x):
return 1 / (1 + np.exp(-x))
def tanh(x):
return np.tanh(x)
def relu(x):
return np.maximum(0, x)
def leaky_relu(x, alpha=0.01):
return np.where(x > 0, x, alpha * x)
def softmax(x):
exp_x = np.exp(x - np.max(x)) # 防止溢出
return exp_x / np.sum(exp_x)
def swish(x):
return x * sigmoid(x)
def gelu(x):
return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))
# 定义输入范围
x = np.linspace(-5, 5, 100)
# 绘制激活函数
plt.figure(figsize=(12, 8))
plt.subplot(3, 3, 1)
plt.plot(x, sigmoid(x), label='Sigmoid')
plt.title('Sigmoid')
plt.grid()
plt.legend()
plt.subplot(3, 3, 2)
plt.plot(x, tanh(x), label='Tanh', color='orange')
plt.title('Tanh')
plt.grid()
plt.legend()
plt.subplot(3, 3, 3)
plt.plot(x, relu(x), label='ReLU', color='green')
plt.title('ReLU')
plt.grid()
plt.legend()
plt.subplot(3, 3, 4)
plt.plot(x, leaky_relu(x), label='Leaky ReLU', color='red')
plt.title('Leaky ReLU')
plt.grid()
plt.legend()
plt.subplot(3, 3, 5)
plt.plot(x, swish(x), label='Swish', color='purple')
plt.title('Swish')
plt.grid()
plt.legend()
plt.subplot(3, 3, 6)
plt.plot(x, gelu(x), label='GELU', color='brown')
plt.title('GELU')
plt.grid()
plt.legend()
plt.subplot(3, 3, 7)
plt.plot(x, softmax(x), label='Softmax', color='cyan')
plt.title('Softmax (example for x=[-2, 0, 2])')
plt.grid()
plt.legend()
plt.tight_layout()
plt.show()