一、持续学习
1. 定义与背景
持续学习是指机器学习模型在面对一系列动态任务或数据流时,能够逐步学习新知识,同时保留对旧任务的记忆,避免“灾难性遗忘”。它模仿了人类终身学习的能力,目标是构建一个在动态环境中持续适应、积累知识的模型。
为什么重要?
- 现实需求:现实世界的数据和任务是动态的(如机器人学习新技能、推荐系统适应用户行为变化)。
- 资源限制:存储所有历史数据或频繁重新训练模型不现实。
- 隐私保护:旧任务数据可能因隐私问题无法重新访问。
关键挑战:
- 灾难性遗忘:学习新任务时,模型参数更新可能覆盖旧任务知识。
- 任务干扰:新任务学习可能干扰旧任务性能。
- 可扩展性:任务数量增加时,模型需要高效管理知识。
- 数据访问限制:旧任务数据可能不可用。
- 计算效率:在有限资源下实现持续更新。
2. 持续学习的分类
持续学习根据任务和数据特性分为以下几种场景:
- 任务增量学习:
- 模型依次学习多个明确的任务,每个任务有独立的标签空间。
- 通常提供任务ID以区分任务。
- 示例:先学习手写数字分类(MNIST),再学习图像分类(CIFAR-10)。
- 新增例子:一个模型先学习垃圾邮件分类(任务1:区分垃圾邮件和正常邮件),再学习新闻文章分类(任务2:区分政治、经济、体育新闻),每个任务使用独立的分类器,但共享部分特征提取层。
- 类增量学习:
- 模型逐步学习新类别,但需要对所有类别进行统一分类,无任务ID。
- 示例:先学习10种动物分类,再学习10种新动物,模型需区分所有20种。
- 新增例子:一个图像分类模型先学习识别猫、狗、鸟等10种动物(任务1),再学习识别鱼、爬行动物等10种新动物(任务2),最终需在无任务ID的情况下区分所有20种动物。
- 数据增量学习:
- 数据分布随时间变化,但任务目标不变。
- 示例:推荐系统根据用户行为变化更新模型。
- 新增例子:一个情感分析模型根据社交媒体用户评论的语言风格变化(如新流行语的出现)持续更新,保持对正面、负面情感的分类能力。
- 在线学习:
- 数据以流式方式到达,模型实时更新。
- 示例:金融预测模型根据实时市场数据调整。
- 新增例子:一个实时文本分类系统根据用户输入的推文流,持续更新以区分正面、负面和中性情感,适应不断变化的语言表达。
3. 持续学习的核心方法
持续学习的方法主要分为三大类:正则化方法、动态架构方法和基于记忆的方法。以下详细介绍每种方法的核心思想、数学原理和代码实现。
(1)正则化方法
核心思想:通过在损失函数中添加正则化项,限制模型参数在学习新任务时偏离旧任务的参数,保护重要权重。
代表方法:弹性权重巩固(EWC)
- 原理:EWC通过Fisher信息矩阵估计每个参数对旧任务的重要性,对重要参数施加强约束。
- 数学公式:损失函数为 L = L new + λ ∑ i F i ( θ i − θ i ∗ ) 2 L = L_{\text{new}} + \lambda \sum_i F_i (\theta_i - \theta_i^*)^2 L=Lnew+λi∑Fi(θi−θi∗)2,其中 L new L_{\text{new}} Lnew 为新任务的损失, F i F_i Fi 为参数 θ i \theta_i θi 对旧任务的重要性(Fisher信息矩阵的对角元素), θ i ∗ \theta_i^* θi∗ 为旧任务训练后的参数值, λ \lambda λ 为正则化强度。
- 实现步骤:
- 训练旧任务,保存模型参数 θ ∗ \theta^* θ∗ 和Fisher信息矩阵。
- 学习新任务时,添加正则化项到损失函数。
- 新增例子:在文本分类任务中,先训练一个模型识别正面和负面评论(任务1),使用EWC保存重要参数;再学习中性和讽刺评论分类(任务2),通过正则化项保护任务1的性能,防止遗忘正面和负面评论的分类能力。
代码示例(EWC实现):
以下是一个基于PyTorch的EWC简单实现,假设任务是MNIST数据集的类增量学习。
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
import numpy as np
# 定义简单神经网络
class SimpleNet(nn.Module):
def __init__(self):
super(SimpleNet, self).__init__()
self.fc1 = nn.Linear(28 * 28, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = x.view(-1, 28 * 28)
x =<