CoTTA:连续的测试时域自适应方法

f3a92b710c421f6c2fc9e6ed78ba20e5.png

文章信息

896deee8bc69fa92889ce15d06130600.png

论文题目为《Continual Test-Time Domain Adaptation》,该文于2022年发表于Conference on Computer Vision and Pattern Recognition (CVPR)会议上。文章提出了一种持续的测试时域自适应方法(CoTTA),旨在应对非稳态和不断变化的目标领域环境,通过减少错误累积和防止灾难性遗忘,以实现源模型的长期适应。

103d80b9e03e4b8566616a4c8c4f311f.png

摘要

94ac89d1498858ad5929164685d9bda6.png

测试时域自适应旨在将源预训练模型适应到目标领域,而无需使用任何源数据。现有研究主要考虑了目标领域静态的情况。然而,在现实世界中,机器感知系统在非稳态和不断变化的环境中运行,目标领域的分布随时间可能会发生变化。现有方法主要基于自我训练和熵正则化,可能会受到这些非稳态环境的影响。由于目标领域随时间的分布变化,伪标签变得不可靠。噪声伪标签可能进一步导致错误累积和灾难性遗忘。为了解决这些问题,文章提出了一种连续的测试时自适应方法(CoTTA),它包括两个部分。首先,文章提出通过使用加权平均和数据增强平均预测来减少错误累积,这些预测通常更准确。另一方面,为了避免灾难性遗忘,在每次迭代中随机恢复一小部分神经元到源预训练权重,以帮助长期保留源知识。所提出的方法使网络中的所有参数能够进行长期自适应。CoTTA易于实施,并可轻松集成到现成的预训练模型中。作者在四个分类任务和一个分割任务上展示了CoTTA的有效性,用于持续测试时自适应,在这些任务中胜过了现有方法。

add249732292557c864d4accc4c0d0bd.png

引言

6e27bddad11c18e01df43090066e7747.png

测试时域自适应旨在通过在测试时从未标记的目标数据中学习,以使源预训练模型适应目标领域数据,源训练数据与目标测试数据之间存在领域偏移,因此需要适应才能实现良好的性能。例如,一个在晴天条件下训练的语义分割模型在雪夜条件下进行测试时可能会出现显著的性能下降。同样,一个预训练的图像分类模型在测试受传感器降级影响的损坏图像时也可能出现这种现象。在许多情况下,自适应需要以在线方式进行。因此,测试时域自适应对于在领域偏移情况下成功应用于实际机器感知应用至关重要。

现有的测试时域自适应方法通常通过使用伪标签或熵正则化来更新模型参数来处理源领域和固定目标领域之间的分布偏移。这些自训练方法对于来自同一个稳态领域的测试数据是有效的。然而,当目标测试数据源源不断地来自时刻变化的环境时,它们可能会变得不稳定。主要有两个原因:首先,在不断变化的环境下,由于分布偏移,伪标签变得更加嘈杂和不准确。因此,早期的预测错误更有可能导致错误积累。其次,由于模型长时间的适应新的数据分布,对于源领域的知识更难保持,从而导致灾难性遗忘。

为了应对这些问题,文章提出了一种连续的测试时自适应方法(CoTTA),以适应不断变化的环境。如图1所示,目标是从一个现成的源预训练模型开始,不断地使其适应当前的测试数据。假设目标测试数据是从一个不断变化的环境中流式传输的。预测和更新是在线执行的,这意味着模型只能访问当前数据流,而不能访问完整的测试数据或任何源数据。这个设定非常贴近实际的机器感知系统。比如在自动驾驶系统中,周围环境不断变化,从晴天到多云再到雨天等天气变化,车辆驶出隧道时摄像头突然曝光过度。在这种非稳态的环境下,感知模型需要能够实时适应并做出决策。

d75eb318929324f6c6198c990731df76.png

图 1 各种适用场景

该方法主要有两个作用,首先能够减少误差积累。在自我训练框架下,通过两种不同的方式来提高伪标签质量。一方面,由于平均教师预测通常比标准模型具有更高的质量,使用加权平均教师模型来提供更准确的预测。另一方面,对于具有较大域间隙的测试数据,使用增强平均预测来进一步提高伪标签的质量。其次,该方法能够保存源知识并避免遗忘。作者将网络中的一小部分神经元随机恢复到预训练的源模型。通过减少误差积累和保留知识,CoTTA能够在不断变化的环境中进行长期适应,并能够训练网络的所有参数,而以前的方法只能训练batchnorm参数。CoTTA可以轻松地集成到任何现成的预训练模型中,而无需它们的源数据。为了验证其有效性,作者在四个分类任务和一个分割任务上进行实验,实验结果表明,在这些任务中使用了持续的测试时自适应方法(CoTTA)的预训练模型,性能显著提高并超过了现有的方法。

文章的贡献总结为:作者提出了一种持续测试时自适应方法,可以有效地将现成的源预训练模型适应到不断变化的目标数据中。通过使用更准确的加权平均和数据增强平均伪标签来减少错误积累。通过明确保留源模型的知识来缓解了长期遗忘效应。所提出的方法显著提高了在分类和分割基准测试中的持续测试时自适应性能。

f78124b4a1316cf5b03296af8aefa2eb.png

问题定义

73daa4d1e79c0118b8275b421adb6bce.png

我们的目标是在推理时,针对一个持续变化的目标领域,以在线方式不断提高现有预训练模型f3bb5c82f155f344ba982fc432ce4dd2.png(其中θ是参数,已经在源数据e8c62113baa4862122944a8ada94fd00.png上训练过)的性能,而不需要访问任何源数据。目标领域的未标记数据XT是按顺序提供的,模型只能访问当前时间步的数据。在时间步t,目标数据4d7e827e20e79a47afb17f7924e800b5.png被提供为输入,模型4754f49707c91742807973a113a01f06.png需要进行预测1541b69147a2aeeb074457c43885ad70.png,并相应地进行自适应以适应未来的输入d76a4701b4dbfa4f6f864eb0d7950ae7.png9f634110fcaffdd19c9245329dfaaaae.png的数据分布不断变化,模型的性能是基于在线预测进行评估的。

这种设置很大程度上是由不断变化的环境中机器感知应用的需求所驱动的。例如,由于位置、天气和时间等因素,自动驾驶汽车的周围环境会不断变化。感知决策需要在线做出,模式需要调整。在线连续测试时间适应设置与现有适应设置之间的主要区别如下表所示,与以往专注于固定目标域的设置相比,作者考虑了对不断变化的目标环境的长期适应。

c8c94e3ecdac74c32a1d15a2dcd40d19.png

d22d80feffbe762f7a5828e6347ea813.png

方法

b701a861b80696ecb3763c6661644086.png

在线的连续测试时自适应方法采用现成的源预训练模型,并以在线方式使其适应不断变化的目标数据。通过使用加权平均和权重增强伪标签来减少错误积累。此外,为了帮助减少持续适应中的遗忘,该方法还保留了源模型中的信息,如图2所示。

f91b8cb48918480153eb76ec23c8d0bb.png

图 2 连续的测试时域自适应方法流程

方法1:加权平均伪标签

给定目标数据和模型2442f0bef4703de6079a1d8a6888957e.png,目标是通过模型的预测结果9481d380e1a1a8885cba90eb942a42fa.png与伪标签之间的交叉熵一致性来进行优化。这里的伪标签是一种用于训练的标签,通常由模型的预测结果生成。作者提到,如果直接使用模型的预测结果作为伪标签,这将导致在目标领域保持不变时有效,对于不断变化的目标数据,由于分布的变化,伪标签的质量可能会显著下降。

在深度学习训练中,通过对训练过程中的多个时间步的模型进行权重平均,通常可以得到比最终模型更准确的模型。这是因为权重平均可以减轻训练过程中的噪声和波动,从而提高了模型的鲁棒性和泛化性能。于是引入了一个称为教师模型的概念,在时间步t=0时,教师模型初始化为与源预训练模型相同。在时间步长t处,伪标签首先由教师85669b113538eb81403d69e623e32520.png生成。然后,学生93c8e80235d3bc6c4aa7a13fa4a3fb20.png通过学生和教师预测之间的交叉熵损失来更新:

df28aaf0ed5ace9b944187150095d603.png

其中,ea14ead2dd7f224363e0042e91015494.png为教师模型软伪标签预测中c类的概率,5c2aeb08791fc4b9909a78d036257a4e.png为主模型(学生模型)的预测。这种损失加强了教师和学生预测之间的一致性。

在学生模型权重89f90c803831da5ff308cd6c66315b8e.png通过上述公式更新后,便使用学生模型的权重来更新教师模型的权重,采用指数移动平均的方法进行更新:

7fd65ce491bb411e1a0f12fe98f9032f.png

其中α是平滑因子,控制了新权重与旧权重的混合程度。对输入数据af60066a9704c22ca6d70d2f4c8f20cd.png的最终预测值是2cf9bc918c5f28221c7fec41a85dae41.png中概率最高的类。

加权平均一致性有两个好处。一是通过使用更准确的加权平均预测作为伪标签目标,模型通过对高质量伪标签的训练在持续自适应过程中受到的误差积累较少。第二是平均教师预测编码了过去迭代中模型的信息,因此,在长期持续适应中不太可能遭受灾难性遗忘,并提高了对新的未知领域的泛化能力。

方法2:权重增强伪标签

在训练阶段数据增强已被广泛应用于提高模型性能。不同的数据增强策略可以手动设计或者通过搜索算法(如自动增强搜索)来确定,以适应不同的数据集和任务。在测试模型性能时,有时也会应用测试时增强(test-time augmentation)。这是一种在测试样本上应用数据增强变换的方法,它已被证明可以提高模型的鲁棒性,然而,通常情况下,测试时增强策略会在训练期间固定下来,而不考虑推理时领域(数据分布)发生变化的情况。在不断变化的环境下,测试分布可能发生巨大变化,这可能使增强策略无效。因此,作者考虑测试时的领域变化,并通过预测置信度来近似领域之间的差异。当领域之间的差异较大时,才会应用数据增强,以减少误差积累。

4f1df1423eb195604d9e7040ee1e701d.png

ca63642eec4fffc5b52662c5bbf0b4b8.png

其中d7fc4cd8784611a290bfd6e7ec07a840.png表示教师模型的增强平均预测,b9da63279b68cd2a5149adbd8b99f494.png为教师模型的直接预测,80ad7325382feefb54b88c39140149e3.png为源预训练模型对当前输入aeaf0d4fa10c868ef76f12fd57a50951.png的预测置信度,9dfe9f709d09f5f8c64915d306b58637.png为置信度阈值。通过计算源预训练模型的预测置信度来估计源域和当前域之间的差异。低置信度可能表示领域差异较大,而相对较高的置信度可能表示领域差异较小。因此,当预测置信度高于设定的阈值时,就会直接使用教师模型的原始预测作为伪标签,而不进行额外的数据增强。但是,当置信度较低时,会额外应用 N 次随机数据增强来提高伪标签的质量。这种方法的目的是根据模型对于当前输入的预测置信度来决定何时以及如何应用数据增强,以提高模型性能并适应领域差异。学生模型的预测与改进的伪标签之间的损失函数为:

e72cd4032e32da43dc91ff495a965bcc.png

c97388afa034d888152827b873767771.png表示学生模型对输入样本的预测中的类别 c 的概率,是从教师模型获得的改进的伪标签中的类别 c 的概率。通过最小化这个损失函数,模型试图使学生模型的预测尽可能接近改进的伪标签,从而更好地适应目标任务或领域。

方法3:随机恢复

虽然更准确的伪标签可以减少错误的积累,但长期自训练的持续适应会不可避免地引入错误并导致遗忘。特别是在数据序列中遇到强烈的域移位,因为强烈的分布移位会导致校准错误甚至错误的预测。在这种情况下,自训练可能只会强化错误的预测。并且在遇到困难的例子后,即使新数据没有严重偏移,模型也可能因为不断的适应而无法恢复。

为了进一步解决灾难性遗忘问题,作者提出了一种随机恢复方法,该方法明确地从源预训练模型中恢复知识。考虑在时间步长为t时,基于方程1的梯度更新后的学生模型227dfb102cced1e246e774d1829d8eee.png内的卷积层为:

4f734a22f71097ba210558b31df71d5c.png

其中*表示卷积运算,99b6886d7f3bfff1541574de100bcc01.png9cae15c50ef1ab9fb3868968587861d5.png表示该层的输入和输出,bee3018ebd475d024789f606990e1814.png表示卷积核。本文提出的随机恢复方法通过以下方式对权值进行更新:

3fcf7760f18e0eb78e5827b7608c5499.png

d27c57ac1a91d4139bbae2e15c572bd8.png

其中表示89c6ff621ca465df5907ffc944b425d0.png元素的乘法。p是一个小的恢复概率,M是与04ae8d6ac818fdb265525f50d0e44f5f.png形状相同的掩模张量,M 按照 Bernoulli 分布进行随机采样,M 的元素取值为0或1,取0的概率为 p,取1的概率为 1-p。掩码张量3e05a670dfba2ddbb8e8bd27f01f7bda.png决定中的哪个元素要恢复到源权重ba24936a6695fe3f411c6bef115472f1.png

随机恢复也可以看作是Dropout的一种特殊形式。通过随机地将可训练权值中的少量张量元素恢复到初始权值,避免了网络偏离初始源模型太远,从而避免了灾难性遗忘。

将精细伪标签与随机恢复相结合,形成了的在线连续测试时间适应(CoTTA)方法,如下所示。

6d6bd4f4744093a6737307d8f0e97053.png

有一个预训练模型(学生模型)2c7855e44311ee76df35cc7fb31110db.png和一个教师模型e50a7c99bba87568a7d840ee7c54a0a3.png,在初始化阶段时间步t=0时,教师模型初始化为与源预训练模型相同。输入时间步长t时的数据流63b5f8a6f22f981745ed1935a3f9a28e.png。首先,从教师模型中生成加权和增强平均伪标签(pseudo-labels),作为学生模型的训练目标。接着,将学生模型的预测结果和该伪标签,根据交叉熵损失更新学生模型。然后,使用学生权重通过指数移动平均更新教师模型的权重。其次,随机恢复一部分学生模型参数。最后,得到预测结果1e36e9562c918f66dafc1d997902c628.png,更新的学生模型10059e32ef2401d18187c6f348d3f4d7.png和更新的教师模型237098ed38e70e57a885d19a2f000e03.png。重复以上步骤,直到模型在目标领域上达到满意的性能。这个过程可以在不断变化的目标数据上持续进行,以适应领域分布的变化。

9663e087e9d95d55b10a380a5672c30d.png

实验

a3c62fd73234b691a0ac0b4fe85059d0.png

作者在五个不同的连续测试时间适应基准任务上评估了他们提出的方法,这些任务包括:CIFAR10-to-CIFAR10C(标准和渐变),CIFAR100-to-CIFAR100C,ImageNet-to-ImageNet-C,Cityscapses-to-ACDC。这些任务代表了不同类型的应用场景,包括图像分类和语义分割。CIFAR10C、CIFAR100C和ImageNet-C是包含15种严重程度为5级的损坏类型。

CIFAR10 to CIFAR10C的实验:

对于在线连续测试时间自适应任务,使用在CIFAR10或CIFAR100数据集上训练好的预训练网络。在测试期间,损坏的图像以在线方式提供给网络。在最大损坏严重等级5下评估各种基准模型。评估是基于遇到数据后立即的在线预测结果。CIFAR10和CIFAR100实验均采用在线连续测试时间自适应方案。

我们首先评估了所提出的模型在CIFAR10到CIFAR10C任务上的有效性。将我们的方法与纯源基准和四种流行的方法进行比较。结果如下表所示。b4541f541f71e51749817978c953ea10.png

CoTTA利用加权和增强平均的一致性,可以优于上述所有方法。错误率显著降低到16.2%。并且通过随机恢复方法,模型在长期适应的过程中性能不会下降。

 消融实验:此外作者还做了消融实验,如上表所示,通过使用教师模型的加权平均伪标签,错误率从20.7%降低到18.3%。这表明加权预测确实比直接预测更准确。通过使用多个增强来进一步细化权重平均预测,我们能够进一步将性能提高到17.4%。最后,通过随机恢复保留源知识,可以大大提高长期的预测。将错误率降低到16.2%。

  鲁棒性实验:通过逐渐改变15种不同程度的损坏类型图片设计10种不同的随机打乱的序列,使用这10个不同序列的平均错误率来评估这些方法,结果如下表所示。CoTTA优于其他方法,错误率只有10.4%。

55279624062976ba2b65c65db6a246b9.png

CIFAR100 to CIFAR100C实验:

在难度更高的cifar100 - cifar100c任务上对其进行了评估。实验结果如下表所示。

6fd4dd6e660a8c70e18004c72b8d120c.png

ImageNet to ImageNet-C的实验:

在实验中作者使用了标准的预训练resnet50模型。在十种不同的损坏顺序下对ImageNet-C实验进行了评估。在严重等级为5的10种不同的损坏类型序列上进行了ImageNet-to-ImageNet-C实验。如下表所示,CoTTA优于其他方法。

430cb942cc04c81601aff7f472af89e6.png

Cityscapes to ACDC的实验:

Cityscapes to ACDC是一个连续的语义分割任务,用它来模拟现实世界中的连续分布变化。为了尽可能重新访问类似环境的场景,并评估该方法的遗忘效果,作者将相同的序列组(四种条件)重复10次(即总共40次:雾- !夜- !雨- !雪- !雾…)。这也为长期适应性能的评估提供了依据。

在更复杂的连续测试时间语义分割Cityscapes to ACDC任务上评估了CoTTA。实验结果如下表所示。结果表明,该方法对语义分割任务也很有效,并且对不同的结构选择具有鲁棒性。与基线相比,我们提出的方法绝对提高了1.9%的mIoU,达到了58.6%的mIoU。

9e08774db4d3c5cd9e4605ae5548ec16.png

c89bac297e53f8256cb18948e580e8b0.png

结论

2d42f3679edf3ba334f98da7435059d7.png

对于在目标域分布随时间不断变化的非平稳环境中持续的测试时自适应产生的错误积累和灾难性遗忘问题,作者提出了一种新的CoTTA方法,该方法由两部分组成。首先,通过使用权重平均和增广平均预测来减少误差积累,这两种预测通常更准确。其次,为了保留来自源模型的知识,随机地将一小部分权重恢复到源预训练的权重。所提出的方法可以集成到现成的预训练模型中,而不需要访问源数据,作者在4个分类任务和1个分割任务上验证了其有效性。

4e817337e74aefd0f80d850f1fbcbcc5.png

Attention

76bc490b5e3c1f9f513f83bc7b50a28b.png

如果你和我一样是轨道交通、道路交通、城市规划相关领域的,可以加微信:Dr_JinleiZhang,备注“进群”,加入交通大数据交流群!希望我们共同进步!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

当交通遇上机器学习

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值