持续学习EWC代码实现

Overcoming catastrophic forgetting in neural networks

论文地址:EWC论文
论文代码:EWC代码,该代码包含大部分持续学习算法的代码
论文中公式推导论文:Elastic Weight Consolidation (EWC): Nuts and Bolts
关于论文的代码和公式推导CSDN上有几篇博客写的也挺不错,但是关于公式推导中的拉普拉斯变化,博客观点不统一,故本篇博客公式推导主要参考Elastic Weight Consolidation (EWC): Nuts and Bolts这篇论文。

一、持续学习简单介绍

持续学习指的是模型在完成新任务的同时不忘记旧任务如何完成的。由于神经网络存在灾难性遗忘,导致很难进行持续学习。目前,《A Continual Learning Survey: Defying Forgetting in Classification Tasks》这篇关于持续学习的综述将持续学习方法主要分为三类:
1.Replay Methods
2.Regularization-Based Methods
3.Parameter Isolation Methods
在这里插入图片描述

EWC属于第二类,基本思想是针对单个任务的神经网络中,有一些网络参数对完成该任务有着重要影响,为了保持对该任务的性能,应当让这些重要参数保持不变或者变化很小。

二、EWC主要思想

EWC主要从概率角度出发,推导出重要度矩阵用来度量网络参数对旧任务的重要程度并得到重要度矩阵即Fisher信息矩阵,为了让这些对旧任务重要的参数在完成新任务时变化不大,在训练新任务时添加了L2正则项并结合重要度矩阵来对完成旧任务重要的网络参数进行约束。

三、EWC公式推导

以下为本人写的公式推导过程,如有错误,尽情批评指正
在这里插入图片描述
在这里插入图片描述

四、EWC代码实现

该代码为针对多任务的EWC实现,和两个任务的EWC实现不同点在于Fisher信息矩阵的处理,多任务的Fisher信息矩阵获得代码如下

 # Fisher ops
        if t>0: # t表示任务序号,从零开始
            fisher_old={}
            for n,_ in self.model.named_parameters():
                fisher_old[n]=self.fisher[n].clone()
        self.fisher=utils.fisher_matrix_diag(t,xtrain,ytrain,self.model,self.criterion)
        if t>0:
            # Watch out! We do not want to keep t models (or fisher diagonals) in memory, therefore we have to merge fisher diagonals
            for n,_ in self.model.named_parameters():
                self.fisher[n]=(self.fisher[n]+fisher_old[n]*t)/(t+1)       # Checked: it is better than the other option,当t=0时,self.fisher=None
                #self.fisher[n]=0.5*(self.fisher[n]+fisher_old[n])

Fisher的编程实现为(参考链接:2.如何计算Fisher信息矩阵
在这里插入图片描述
关于Fisher信息矩阵的计算函数如下:

def fisher_matrix_diag(t,x,y,model,criterion,sbatch=20):
    # Init
    fisher={}
    for n,p in model.named_parameters():
        fisher[n]=0*p.data
    # Compute
    model.train()
    for i in tqdm(range(0,x.size(0),sbatch),desc='Fisher diagonal',ncols=100,ascii=True):
        b=torch.LongTensor(np.arange(i,np.min([i+sbatch,x.size(0)]))).cuda()
        images=torch.autograd.Variable(x[b],volatile=False)
        target=torch.autograd.Variable(y[b],volatile=False)
        # Forward and backward
        model.zero_grad()
        outputs=model.forward(images)
        loss=criterion(t,outputs[t],target)
        loss.backward()
        # Get gradients
        for n,p in model.named_parameters():
            if p.grad is not None:
                fisher[n]+=sbatch*p.grad.data.pow(2)
    # Mean
    for n,_ in model.named_parameters():
        fisher[n]=fisher[n]/x.size(0)
        fisher[n]=torch.autograd.Variable(fisher[n],requires_grad=False)
    return fisher

关于EWC的损失函数实现代码如下:

def criterion(self,t,output,targets):
        # Regularization for all previous tasks
        loss_reg=0
        if t>0:
            for (name,param),(_,param_old) in zip(self.model.named_parameters(),self.model_old.named_parameters()):
                loss_reg+=torch.sum(self.fisher[name]*(param_old-param).pow(2))/2 # EWC的损失函数的正则化部分

        return self.ce(output,targets)+self.lamb*loss_reg
  • 15
    点赞
  • 75
    收藏
    觉得还不错? 一键收藏
  • 10
    评论
### 回答1: PyTorch是一个开源的深度学习框架,它为持续学习提供了很好的支持。持续学习是指通过不断地学习新的数据、调整模型和继续训练,从而实现模型的优化和更新。下面是使用PyTorch实现持续学习的一些关键步骤: 1. 数据处理:将新的数据加载到PyTorch中,并进行预处理操作,例如数据标准化、数据增强等。可以使用PyTorch中的数据加载器(DataLoader)和数据预处理工具(transform)加快处理过程。 2. 模型加载:加载已经训练好的模型参数,可以使用PyTorch的torch.load()函数加载先前训练模型的参数。 3. 模型调整:根据新的数据特点,对模型进行微调或调整。可以使用PyTorch提供的灵活的模型定义和修改方式,例如修改模型的层结构、修改激活函数等。 4. 优化器选择:选择合适的优化器,例如Adam、SGD等,以在持续学习过程中调整模型的权重。 5. 训练过程:使用新的数据对模型进行训练,并反复迭代调整模型。可以使用PyTorch提供的自动微分功能,加快梯度计算和模型更新过程。 6. 模型保存:在每次训练迭代结束后,保存模型的最新参数。可以使用PyTorch的torch.save()函数保存模型参数。 7. 持续学习:重复上述步骤,对新的数据进行处理、模型调整和训练过程,以实现模型的持续学习。 通过上述步骤,使用PyTorch可以实现持续学习的过程。凭借其灵活性和强大的计算能力,PyTorch能够满足各种深度学习模型对于持续学习的需求,并为模型的优化提供支持。同时,PyTorch还提供了丰富的工具和函数,帮助开发者更高效地实现持续学习。 ### 回答2: PyTorch是一个开源的机器学习框架,它提供了丰富的工具和功能来支持持续学习持续学习指的是通过新数据的输入,持续改进和更新现有的模型,以适应不断变化的环境和任务。 PyTorch提供了一个灵活和可扩展的架构,使得持续学习变得更加容易。以下是在PyTorch中实现持续学习的一些关键步骤: 1. 数据管理:持续学习需要处理不断变化的数据。PyTorch中的DataLoader和Dataset类可以帮助加载和管理数据集。您可以创建一个数据加载器来批量加载新的数据集,并将其与之前的数据集合并。 2. 模型更新:当有新的数据到达时,您可以使用PyTorch的优化器来更新模型的参数,以适应新的数据。您可以使用反向传播算法计算损失,并调用优化器的`step`函数来更新模型的参数。 3. 继续训练:持续学习意味着在之前训练的基础上继续学习。您可以加载之前训练保存的模型,并在新的数据上进行训练。在PyTorch中,您可以使用`torch.load`函数加载之前训练的模型,并通过调用`train`函数来继续训练。 4. 模型评估:持续学习需要在新的数据上进行模型评估,以评估其性能和适应能力。您可以使用PyTorch中的评估函数和指标来评估模型的准确性和效果。 5. 灵活性:PyTorch的灵活性使得您可以自定义和调整模型结构,以适应不同的任务和数据。您可以根据新的数据特点调整模型的层次、结构和参数。 总之,PyTorch为持续学习提供了丰富的功能和易用的工具。通过管理数据、更新模型、继续训练和模型评估,您可以在PyTorch中有效地实现持续学习。 ### 回答3: 在PyTorch中实现持续学习的关键是使用动态图的特性和灵活的模型更新方法。 首先,PyTorch的动态图机制允许我们在运行时构建和修改模型图,这使得持续学习更加容易。我们可以将新的数据集添加到已经训练的模型上,并通过反向传播来更新模型的权重。这样,我们可以通过在已有模型上继续训练来逐步适应新的数据,而无需重新训练整个模型。 其次,持续学习的另一个重要问题是防止旧知识的遗忘。为了解决这个问题,我们可以使用增量学习方法,如Elastic Weight Consolidation(EWC)或Online Deep Learning(ODL)。这些方法通过使用正则化项或定义损失函数来限制新训练数据对旧知识的影响,从而保护旧有的模型参数。 此外,我们还可以使用PyTorch提供的模型保存和加载功能来实现持续学习。我们可以定期保存模型的参数和优化器状态,以便在需要时恢复模型,并继续训练过程。通过这种方式,我们可以持续积累更多的数据和知识,而无需从头开始每次都重新训练模型。 总的来说,PyTorch提供了灵活的动态图和丰富的工具,使得实现持续学习变得简单。我们可以通过动态修改模型图、使用增量学习方法来应对新数据和旧知识的挑战,并使用模型保存和加载功能来持续积累数据和知识。这些方法的组合可以帮助我们在PyTorch中实现高效的持续学习
评论 10
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值