论文笔记:Gradient Regularized Contrastive Learning for Continual Domain Adaptation

论文地址:https://arxiv.org/abs/2007.12942

代码地址:尚未开源


1 Main Idea(梯度正则化)

1)令梯度不损害源域特征的判别能力,反过来提高模型对目标域的自适应能力;

2)约束梯度不提高在旧目标域上的分类损失,使得模型在旧目标域上的表现得到保持,同时对新目标域有很好的适应能力。

在本文中,作者认为保留源域特征的可分辨能力可以提高模型对不同目标域的适应能力,也即由于源域的标签是已知的,因此从源域中学到的分辨能力可以引导对所有域的适应。多任务学习即源域上的分类损失+域自适应损失,当最小化多任务损失时,源域上分类损失可能会提高,意味着模型在源域上的分辨能力降低。反过来,本文的模型就是要约束源域的分类损失不增加(sourcediscriminative constraint)。

此外,曾经有学者通过在对抗训练框架上添加一步反演(incorporate a replay)。但是这些方法假设连续出现的目标域漂移是服从某种特定模式的,而当遭遇灾难性遗忘(catastrophic forgetting)时,这个猜想不攻自破。与此相反,当模型在新目标域适应时,作则提出令旧目标域的分类损失不增加(target memorization constraint),这个约束可以让模型在适应新目标域同时,不损失在旧目标域上的生成能力。

前者约束对应于源域上分类损失的参数梯度必须是正的,后者约束对应于每一个旧目标域上分类损失的参数梯度必须是正的。(梯度恒为正=>损失不增)。后者用到由clustering《Learning to Cluster Faces via Confidence and Connectivity Estimation》生成的伪标签,这些伪标签是high-quality的,因为旧目标域的特征有较好的分辨能力,而且作者还会把低置信度的样本滤除掉。

2 Problem Formulation

提出了两个指标,衡量模型对于连续出现的目标域的适应能力:average accuracy (ACC);average backward transfer (BWT),补充一个average forward transfer (FWT):

ACC代表模型在完成所有序列自适应任务后,在所有域上的表现。BWT代表模型在适应域时对历史域的影响。负的BWT即适应新域会降低在历史域上的表现。这两个指标越大越好。

3 Methodology

当模型适应第t个目标域时,基准框架是基于在domain-episodic记忆和一个特征库上的对比学习。GRCL的关键创新点是基于联合训练样本(源域、domain-episodic记忆和新目标域)上两个对比形式的约束。Source discriminative constraint可以保持源域样本的辨别能力,从而提高目标域的适应能力。Target memorization constraint解决旧目标域上的灾难性遗忘问题。

3.1 Baseline Framework with Contrastive Learning

《Unsupervised feature learning via non-parametric instance discrimination》《Momentum contrast for unsupervised visual representation learning》《A simple framework for contrastive learning of visual representations》展示了将图像映射到向量空间的能力,相似图像靠的近,反之离得远。作者受此启发,使用对比损失来推动目标域实例朝着与其外观相似的源域实例靠近,将这些特征放到一个特征库(Feature Bank)里并提出联合对比损失。

Feature Bank     ,其中是输入x的一个表征,由计算而得,是一个基于CNN的编码器,是一个在模型适应了t-1个目标域后的MLP映射模块。所有的特征都用标准化。训练的每一代中,mini-batch中的编码特征会用来更新记忆库

Unified Constrastive Loss     为了让模型适应第t个目标域,CNN模块初始化,MLP映射模块初始化,特征库初始化。而且每一个mini-batch中的的比例是固定的。《Representation learning with contrastive predictive coding》中对每个batch计算对比损失函数为:

Q是一个总体特征向量,x代表训练batch的样本。是对于q的positive key,可以定义为样本x储存在对应的特征。除了以外,称为q的negative key,τ固定为0.07。

3.2 Source Discriminative Constraint

对比损失可以通过提取视觉上相似的样本,缩小域间的差距,但可能会损害源域上特征的判别能力。由于源域的标签是已知的,从源域上的学到的知识对所有目标域是有价值的。因此在最小化对比损失时,增加source discriminative constraint约束,使源域上单调不增:

为更新模型的向量,的梯度,受《Gradient episodic memory for continual learning》启发,上式可以写成:

其中是更新向量和源域分类损失的起始梯度的内积。

如图,如果的角度小于π/2,通过最小化对比损失函数不会增加源域上的分类损失。因此作者使用来更新模型参数。故令在靠近的同时满足上式。

3.3 Target Memorization Constraint

主要解决catastrophic forgetting问题,每一段domain-episodic记忆的分类损失不增:

但是如果单独计算每一个,会增加计算负担,因此采用一种近似方法,只用计算中样本块的损失:

3.4 Overall Formulation and Solution of GRCL

为了将梯度正则化插入到对比损失最小化中,将目标函数min修改为最小化是对比损失梯度,是更新网络的梯度。为了最小化对比损失,一定要按照两个约束尽可能地接近

这是一个二次规划问题,在对偶空间中解算,使得只有两个变量就可以实现最小化QP问题:

其中而且我们删除了的常数项。一旦上式的解找出来了,我们可以通过解上述二次规划问题。

Discussion 以前解决域自适应或持续域自适应的paper都是用多任务学习的方法来实现的,损失函数写成。但作者提出的方法和多任务学习不同:1)GRGL保证参数的更新不会损害源域和旧目标域的分类损失。反过来,多任务损失只会学习总体损失的最小化,而没有source discriminative constraint或target memorization constraint约束;2)多任务学习的参数取舍和的不一样,前者是人工设定的,后者是由计算出来的,并在每一代是自动适配的。

  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值