Learning Without Forgetting 笔记及实现

LWF简介

LWF是结合知识蒸馏(KD)避免灾难性遗忘的经典持续学习方法。本质上是通过旧网络指导的输出对在新任务训练的网络参数进行平衡,从而得到在新旧任务网络上都表现较好的性能。

方法对比

compare
a.从头开始训练
b.微调:在旧任务的网络基础上以较小的学习率学习新任务 另一种意义上的initialization?
c.联合训练:使用所有任务的数据一起训练
d.特征提取:将旧任务的参数固定作为特征提取器,添加新的层训练新任务

LWF算法流程

algorithm
θ s \theta_s θs为在old task上pretrained网络CNN的共享参数
θ o \theta_o θo为每个old task的特定参数(可理解为网络的i最后的classifier head)
( X n , Y n ) (X_n,Y_n) (Xn,Yn) new task的数据

初始化:
1.将新数据 ( X n , Y n ) (X_n,Y_n) (Xn,Yn) 输入在旧任务pretrained网络中得到一组respond Y o Y_o Yo
2.将new task对应的classifier head参数随机初始化(加快训练的常见手段)

训练:
Y o ^ \hat{Y_o} Yo^ 为待训练网络CNN 对应old task的输出,最开始 θ o \theta_o θo= θ o ^ \hat{\theta_o} θo^ , θ s \theta_s θs= θ s ^ \hat{\theta_s} θs^
Y n ^ \hat{Y_n} Yn^ 为待训练网络对应new task的输出,最开始 θ n \theta_n θn= θ n ^ \hat{\theta_n} θn^ , θ s \theta_s θs= θ s ^ \hat{\theta_s} θs^
优化目标为
θ s ∗ , θ o ∗ , θ n ∗ ← argmin ⁡ θ ^ s , θ ^ o , θ ^ n ( λ o L o l d ( Y o , Y ^ o ) + L n e w ( Y n , Y ^ n ) + R ( θ ^ s , θ ^ o , θ ^ n ) ) \theta_{s}^{*}, \theta_{o}^{*}, \theta_{n}^{*} \leftarrow \underset{\hat{\theta}_{s}, \hat{\theta}_{o}, \hat{\theta}_{n}}{\operatorname{argmin}}\left(\lambda_{o} \mathcal{L}_{o l d}\left(Y_{o}, \hat{Y}_{o}\right)+\mathcal{L}_{n e w}\left(Y_{n}, \hat{Y}_{n}\right)+\mathcal{R}\left(\hat{\theta}_{s}, \hat{\theta}_{o}, \hat{\theta}_{n}\right)\right) θs,θo,θnθ^s,θ^o,θ^nargmin(λoLold(Yo,Y^o)+Lnew(Yn,Y^n)+R(θ^s,θ^o,θ^n))
第一项可以理解为old task的子优化目标,第二项为new task的优化目标,第三项为正则化项。
可以发现整个训练过程和joint training很相似,但是最大的不同是LWF没有用到old task data,而是巧妙地用KD损失去平衡old task的性能。至于KD则体现在以下公式:
L o l d ( y o , y ^ o ) = − H ( y o ′ , y ^ o ′ ) = − ∑ i = 1 l y o ′ ( i ) log ⁡ y ^ o ′ ( i ) \begin{aligned} \mathcal{L}_{o l d}\left(\mathbf{y}_{o}, \hat{\mathbf{y}}_{o}\right) &=-H\left(\mathbf{y}_{o}^{\prime}, \hat{\mathbf{y}}_{o}^{\prime}\right) \\ &=-\sum_{i=1}^{l} y_{o}^{\prime(i)} \log \hat{y}_{o}^{\prime(i)} \end{aligned} Lold(yo,y^o)=H(yo,y^o)=i=1lyo(i)logy^o(i)
l l l 是label的数量,而 y ^ o ′ ( i ) \hat{y}_{o}^{\prime(i)} y^o(i) y o ′ ( i ) y_{o}^{\prime(i)} yo(i) y ^ o ( i ) \hat{y}_{o}^{(i)} y^o(i) y o ( i ) {y}_{o}^{(i)} yo(i) 的修正版本,也就是这里体现了KD的概念, y o ′ ( i ) y_{o}^{\prime(i)} yo(i) 是soft target,而 y ^ o ′ ( i ) \hat{y}_{o}^{\prime(i)} y^o(i) 为网络预测概率值。

y o ′ ( i ) = ( y o ( i ) ) 1 / T ∑ j ( y o ( j ) ) 1 / T , y ^ o ′ ( i ) = ( y ^ o ( i ) ) 1 / T ∑ j ( y ^ o ( j ) ) 1 / T y_{o}^{\prime(i)}=\frac{\left(y_{o}^{(i)}\right)^{1 / T}}{\sum_{j}\left(y_{o}^{(j)}\right)^{1 / T}}, \quad \hat{y}_{o}^{\prime(i)}=\frac{\left(\hat{y}_{o}^{(i)}\right)^{1 / T}}{\sum_{j}\left(\hat{y}_{o}^{(j)}\right)^{1 / T}} yo(i)=j(yo(j))1/T(yo(i))1/T,y^o(i)=j(y^o(j))1/T(y^o(i))1/T
所以网络在训练时,第一部分的loss使得网络的输出概率值一定程度上贴近old task

总结

LWF其实质上是结合了KD和微调,优势在于训练相比joint training更快,且不需要访问先前的数据。但连续学习多个任务仍然避免不了灾难性遗忘

实现

以下是基于pytorch的简单复现,废话不多说贴上code

# 准备数据集
n_classes = 10 # split_mnist数据集一共10类
n_tasks = 5
per_classes_task = int(n_classes / n_tasks)
split_mnist = SplitMNIST(n_experiences=n_tasks, seed=0, 
                         return_task_id=True, shuffle=False)

train_dataset = split_mnist.train_stream[0].dataset
test_dataset = split_mnist.test_stream[0].dataset

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=1000, shuffle=True)

训练旧任务

def kaiming_normal_init(m):
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
    elif isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight, nonlinearity='sigmoid')
        

def train(epoch, model, optimizer, criterion):
    print('\nEpoch: %d' % epoch)
    model.train()
    train_loss = 0.0
    correct = 0
    total = 0
    for batch_id, (x, y, t) in enumerate(train_loader):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        y_pred = model(x)
        loss = criterion(y_pred, y)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, y_pred = y_pred.max(1)
        total += len(y)
        correct += y_pred.eq(y).sum().item()

        progress_bar(batch_id, len(train_loader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     % (train_loss / (batch_id + 1), 100. * correct / total, correct, total))

    return train_loss / (batch_id + 1)


def test(epoch, model, criterion):
    global best_acc
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_id, (x, y, t) in enumerate(test_loader):
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            loss = criterion(y_pred, y)

            test_loss = loss.item()
            _, y_pred = y_pred.max(1)
            total += len(y)
            correct += y_pred.eq(y).sum().item()

            progress_bar(batch_id, len(test_loader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                         % (test_loss / (batch_id + 1), 100. * correct / total, correct, total))

    acc = 100. * correct / total
    if acc > best_acc:
        print('Saving..')
        state = {
            'model': model.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/ckpt_mnist.pth')
        best_acc = acc
    return acc
   
# 训练和测试
device = 'cuda' if torch.cuda.is_available() else 'cpu'
epochs = 10
best_acc = 0.0
lr = 0.01
pre_model = SimpleMLP(num_classes=per_classes_task, hidden_size=256).to(device)
print(pre_model)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(pre_model.parameters(), lr=lr,
                          momentum=0.9, weight_decay=5e-4)
for epoch in range(epochs):
    train(epoch, pre_model, optimizer, criterion)
    test(epoch, pre_model, criterion)

LWF

split_mnist = SplitMNIST(n_experiences=n_classes, seed=0, 
                         return_task_id=True, shuffle=False)
# 取第2个2分类任务
train_dataset = split_mnist.train_stream[1].dataset
test_dataset = split_mnist.test_stream[1].dataset
# 取第1个2分类任务测试LWF在旧任务上的性能
val_dataset = split_mnist.test_stream[0].dataset
train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=1000, shuffle=True)
val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=1000, shuffle=True)
net_new = SimpleMLP(num_classes=per_classes_task, hidden_size=256).to(device)
net_old = SimpleMLP(num_classes=per_classes_task, hidden_size=256).to(device)
oor = torch.load('checkpoint/ckpt_mnist.pth')
net_new.load_state_dict(oor['model'])
net_old.load_state_dict(oor['model'])
incremental_class = per_classes_task
# 获取前一个任务模型的分类头数量
in_features = net_old.classifier.in_features
out_features = net_old.classifier.out_features
# 提取分类头中参数
weight = net_old.classifier.weight.data
bias = net_old.classifier.bias.data
# 新头数量
new_out_features = incremental_class + out_features 
# 构建新分类器
new_fc = nn.Linear(in_features, new_out_features)
kaiming_normal_init(new_fc.weight)
# 新任务模型的前两个头被替换,剩余头用来学习新类
new_fc.weight.data[:out_features] = weight
new_fc.bias.data[:out_features] = bias
net_new.classifier = new_fc
net_new = net_new.to(device)
print('new head numbers:', net_new.classifier.out_features)
# 确保前一个任务模型不参与反向传播
for param in net_old.parameters():
    param.requires_grad = False

改变训练,测试方法

def train(alpha, T, epoch):
    print('\nEpoch: %d' % epoch)
    net_new.eval()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (x, y, t) in enumerate(train_loader):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        outputs = net_new(x)
        soft_y = net_old(x)
        # 新类的Loss
        loss1 = criterion(outputs, y)
        outputs_S = F.softmax(outputs[:, :out_features] / T, dim=1)
        outputs_T = F.softmax(soft_y[:, :out_features] / T, dim=1)
        loss2 = outputs_T.mul(-1 * torch.log(outputs_S))
        loss2 = loss2.sum(1)
        loss2 = loss2.mean() * T * T
        # loss = loss1 * alpha + loss2 * (1 - alpha)
        loss = loss1 + alpha * loss2
        loss.backward(retain_graph=True)
        # loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, y_pred = outputs.max(1)
        total += len(y)
        correct += y_pred.eq(y).sum().item()

        progress_bar(batch_idx, len(train_loader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     % (train_loss / (batch_idx + 1), 100. * correct / total, correct, total))

    return train_loss / (batch_idx + 1)

def test(alpha, T, epoch):
    global best_acc
    net_new.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets, t) in enumerate(test_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net_new(inputs)
            soft_target = net_old(inputs)
            loss1 = criterion(outputs, targets)
            outputs_S = F.softmax(outputs[:, :out_features] / T, dim=1)
            outputs_T = F.softmax(soft_target[:, :out_features] / T, dim=1)
            loss2 = outputs_T.mul(-1 * torch.log(outputs_S))
            loss2 = loss2.sum(1)
            loss2 = loss2.mean() * T * T
            loss = loss1 * alpha + loss2 * (1 - alpha)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += len(targets)
            correct += predicted.eq(targets).sum().item()
            progress_bar(batch_idx, len(test_loader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                         % (test_loss / (batch_idx + 1), 100. * correct / total, correct, total)) 
    acc = 100. * correct / total
    if acc > best_acc:
        print('Saving..')
        state = {
            'model': net_new.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/LWF_ckpt.pth')
        best_acc = acc
    return acc


def val(epoch): # 用于测试旧任务
    net_new.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets, t) in enumerate(val_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net_new(inputs)
            _, predicted_old = outputs.max(1)
            total += len(targets)
            correct += predicted_old.eq(targets).sum().item()
            progress_bar(batch_idx, len(val_loader), 'Acc: %.3f%% (%d/%d)'
                         % (100. * correct / total, correct, total))
    return 100. * correct / total

训练和测试

# 简单实现,超参非最佳参数
T = 2
alpha = 0.5  # 
criterion = nn.CrossEntropyLoss()
best_acc = 0.0

optimizer = optim.SGD(filter(lambda p: p.requires_grad, net_new.parameters()), lr=0.01,
                          momentum=0.9, weight_decay=5e-4)
for epoch in range(epochs):
    train_loss = train(alpha, T, epoch)
    acc_new = test(alpha, T, epoch)
    acc_old = val(epoch) 
torch.save(net_new.state_dict(), 'model.pth')

论文地址

  • 15
    点赞
  • 47
    收藏
    觉得还不错? 一键收藏
  • 10
    评论
learning without forgetting是指在进行连续学习任务时,保持之前所学习知识的不被遗忘。为了实现learning without forgetting,可以使用PyTorch这一深度学习框架。 在PyTorch中,可以使用增量学习(incremental learning)的方法。具体步骤如下: 1. 定义初始模型:首先,定义一个初始模型,用于解决第一个学习任务。可以使用PyTorch中的Module类来创建模型,并选择适当的网络结构。 2. 学习第一个任务:使用第一个任务的数据集对模型进行训练。可以使用PyTorch提供的DataLoader类来加载数据集,使用优化器(如Adam或SGD)和损失函数(如交叉熵损失)对模型进行训练。 3. 保存模型参数:在完成第一个任务的训练后,将模型的参数保存起来。可以使用torch.save()函数将参数保存到磁盘上的文件中。 4. 准备新任务:准备新的数据集和标签,用于学习新的任务。可以使用相同的网络结构或者更改网络结构,根据新的任务要求进行适当的调整。 5. 加载之前的模型参数:在开始新的任务训练之前,使用torch.load()函数加载之前保存的模型参数。 6. 设置学习率:由于新的任务可能与之前的任务有不同的难度或重要性,可以设置不同的学习率来适应新任务的特点。可以使用PyTorch中的scheduler类或手动调整学习率。 7. 学习新任务:使用新的数据集对模型进行更新训练。可以使用先前定义的优化器和损失函数,使用torch.nn.Module的train()方法进行训练。 通过以上步骤,可以在PyTorch中实现learning without forgetting。重要的是保存和加载已训练模型参数,并根据新任务的要求进行适当的调整。同时,可以根据需要设置学习率等超参数,以更好地适应不同任务的特点。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 10
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值