一、简介
AMS可以在无监督和在线学习中计算网络参数的重要性。给与新数据可以计算出网络参数的特征重要性,基于模型数据的L2范数的平方,其参数的梯度反应新数据预测的敏感性,将其作为权重,让其保守变化,提高模型的泛化能力和减少模型的复杂度。
首次将基于未标记数据的参数重要性调整网络需要(不要)忘记的内容的能力,这可能会因测试条件而异。
二、重要贡献
- 提出 AMS
- 我们展示了MAS的局部变体是如何与Hebbian学习计划联系在一起的
- 方法达到了 SOTA, 方法同样适用于对象识别和预测(输出为embedding而不是softmax)
3. MAS
算法
3.1 参数重要性计算
在MAS
中损失函数如下, 模型在学习任务B之前学习任务A。
L B = L ( θ ) + ∑ i λ 2 Ω i ( θ i − θ A , i ∗ ) 2 \mathcal{L}_B = \mathcal{L}(\theta) + \sum_{i} \frac{\lambda}{2} \Omega_i (\theta_{i} - \theta_{A,i}^{*})^2 LB=L(θ)+i∑2λΩi(θi−θA,i∗)2
相对EWC
来说, 在损失函数中
F
i
F_i
Fi 被
Ω
i
\Omega_i
Ωi 替代,
Ω
i
\Omega_i
Ωi 计算方法如下
Ω i = ∣ ∣ ∂ ℓ 2 2 ( M ( x k ; θ ) ) ∂ θ i ∣ ∣ \Omega_i = || \frac{\partial \ell_2^2(M(x_k; \theta))}{\partial \theta_i} || Ωi=∣∣∂θi∂ℓ22(M(xk;θ))∣∣
x
k
x_k
xk 是之前任务中的样本数据。所以
Ω
\Omega
Ω是所学习的网络输出的平方L2范数的梯度。 目的:为了在其梯度中寻找对新任务预测敏感的参数,让其保守变化。有效防止与先前任务相关的重要知识被覆盖.
论文中提出的方法是通过从模型的每一层获取平方L2范数输出的局部版本。 下面实现全局版本, 仅用通过模型的最后一层获取输出。
3.2 python实现
具体的应用案例可以看笔者的github: AMS_Train.ipynb
class mas(object):
def __init__(self, model, dataloader, device, prev_guards=[None]):
self.model = model
self.dataloader = dataloader
# 提取模型全部参数
self.params = {n: p for n, p in self.model.named_parameters() if p.requires_grad}
# 参数初始化
self.p_old = {}
self.device = device
# 保存之前的 guards
self.previous_guards_list = prev_guards
# 生成 Omega(Ω) 矩阵
self._precision_matrices = self._calculate_importance()
for n, p in self.params.items():
# 保留原始数据 - 保存为不可导
self.p_old[n] = p.clone().detach()
def _calculate_importance(self):
out = {}
# 初始化 Omega(Ω) 矩阵(全部填充0)并加上之前的 guards
for n, p in self.params.items():
out[n] = p.clone().detach().fill_(0)
for prev_guard in self.previous_guards_list:
if prev_guard:
out[n] += prev_guard[n]
self.model.eval()
if self.dataloader is not None:
number_data = len(self.dataloader)
for x, y in self.dataloader:
self.model.zero_grad()
x, y = x.to(self.device), y.to(self.device)
pred = self.model(x)
# 生成 Omega(Ω) 矩阵.
# 网络输出 L2范数平方的梯度
loss = torch.mean(torch.sum(pred ** 2, axis=1))
loss.backward()
for n, p in self.model.named_parameters():
out[n].data += torch.sqrt(p.grad.data ** 2) / number_data
out = {n: p for n, p in out.items()}
return out
def penalty(self, model: nn.Module):
loss = 0
for n, p in model.named_parameters():
# 最终的正则项 = Omega(Ω)权重 * 权重变化平方((p - self.p_old[n]) ** 2)
_loss = self._precision_matrices[n] * (p - self.p_old[n]) ** 2
loss += _loss.sum()
return loss
def update(self, model):
return