难点:
模型增量更新后出现灾难性遗忘。
技术方案:
-
弹性权重巩固
python
import torch def elastic_weight_consolidation(optimizer, importance, prev_params): for param, imp, prev_p in zip(optimizer.param_groups[0]['params'], importance, prev_params): loss = 0.5 * imp * (param - prev_p).pow(2) loss.backward()
-
记忆重放
python
from replay_buffer import ReplayBuffer buffer = ReplayBuffer(max_size=10000) buffer.add(previous_samples) def train_step(current_batch): replay_batch = buffer.sample(100) loss = model(current_batch) + model(replay_batch) return loss
效果:
持续学习后模型保持 98.6% 的旧知识保留率,新知识准确率提升 22%。