DLTTA:跨域医学图像测试时间自适应的动态学习率

论文:https://arxiv.org/abs/2205.13723

代码:GitHub - med-air/DLTTA: [IEEE TMI'22] DLTTA: Dynamic Learning Rate for Test-time Adaptation on Cross-domain Medical Images

 发表刊物:TMI2022

摘要

测试时间自适应(TTA)已成为有效解决不同机构医学图像测试时间跨域分布偏移的重要课题。以前的TTA方法有一个共同的限制,即对所有测试样本使用固定的学习率。这样的实践对于TTA来说是次优的,因为测试数据可能是顺序到达的,因此分布转移的规模会经常变化。为了解决这个问题,我们提出了一种新的测试时间适应的动态学习率调整方法,称为DLTTA,该方法动态调节每个测试图像的权重更新量,以考虑其分布变化的差异。具体来说,我们的DLTTA配备了一个基于记忆库的估计方案,以有效地测量给定测试样本的差异。基于这个估计的差异,然后开发一个动态学习率调整策略,以实现每个测试样本的适当程度的适应。我们的DLTTA在视网膜光学相干断层扫描(OCT)分割、组织病理学图像分类和前列腺3D MRI分割等三个任务中得到了广泛的证明。我们的方法实现了有效和快速的测试时间自适应,并且与当前最先进的测试时间自适应方法相比具有一致的性能改进。

背景

DTTA[20]和ATTA[21]分别使用自编码器学习形状先验和对齐特征空间

首先,为了捕捉模型的渐进变化,我们维护了一个memory bank,以动态的方式缓存以前测试数据的特征和预测对。然后,为了估计新到来的测试样本的差异,我们从内存库中检索语义相似的样本,并计算缓存预测与派生差异的当前输出之间的Kullback-Leibler (KL)散度。根据估计的差异,我们进一步调整学习率,以实现对在线测试数据的适当适应。

相关工作

Test Time Adaption

仅基于顺序到达的未标记测试数据来调整深度模型最近引起了越来越多的兴趣。
与传统的UDA方法需要访问训练数据并收集足够数量的测试数据不同,测试时间自适应方法能够利用单个或批量测试数据提供的分布信息更新模型[16]-[18],[22]-[24]。

PTBN, TTT,TTT++,Tent,T3A

UDA, DTTA(学习先验形状)

Dynamic Learning Rate

为了在模型训练过程中调整学习率,

1) 一些预定义的学习率调度器,如线性阶跃衰减[34]、指数衰减、余弦/正弦退火[35],已经被提出并广泛用于训练深度神经网络。然而,单纯地衰减或循环学习率对于复杂的情况可能是不够的,例如面对分布变化的新环境时的测试时间适应能力。

2) 也有研究利用梯度来调整学习率。Maclaurin等人[36]引入了一种可逆学习技术来计算关于学习超参数的梯度。然后通过内部优化调整学习率。Baydin等人[37]计算相对于学习率的梯度,并在每次迭代时在线动态调整学习率更新。虽然这些方法可以通过梯度更新来动态调整学习率,但需要对标记数据进行准确的监督来计算更新的梯度信号,这在测试时间自适应场景中是无法实现的。

我们的想法是根据估计的分布移位实现动态学习率调整。如何以一种无监督的方式,仅根据模型参数和当前测试样本来测量差异是一个具有挑战性的问题,目前还没有得到解决。

最近一项需要源训练数据进行自适应的研究[38]提出,通过计算基于多个分类器输出一致性的置信度得分来推断“更容易”和“更难”的测试数据。虽然这项工作与我们在估计测试图像差异方面的想法相似,但它限制了具有多个不同分类器的特定网络架构,并且它使用识别的“更容易”的测试数据来产生伪标签,而不是设计动态自适应策略。Lee等人[39]利用训练数据存储的BN统计量与测试数据进行比较,作为分布移位的度量。由于在测试时模型参数不断更新,训练数据的BN统计量不能代表更新后的模型状态,导致差值估计不准确。我们提出的基于记忆库的差异测量方法既能捕捉到模型的发展过程,又能捕捉到测试数据的分布变化,从而为测试时间自适应的有效动态学习率调整提供最新的差异估计。

主要贡献

1)提出了一种新的动态学习率调整框架,用于跨域医学图像分析的测试时间适应。据我们所知,这是第一个探索动态学习策略的工作,以克服在测试时模型适应的推理数据的不同分布转移。

2)我们设计了一种新的基于记忆库的差异估计策略来模拟适应需求,在此基础上,对每个特定的测试样本进行在线调整学习率。

3)我们已经验证了我们的方法在三种不同的医学成像模式的2D或3D模型上的分类和分割任务的有效性。实验结果表明,我们的DLTTA适用于不同的网络架构,并且始终优于当前最先进的TTA方法。

方法

Test-time Adaptation Overall Framework 测试时间适应总体框架

在将模型应用于新的测试数据时,一个问题是新的测试样本可能遵循未知的数据分布,导致模型预测性能严重下降。以往的领域自适应方法需要收集足够数量的测试数据,而在实际应用中,测试样本通常以不同的分布位移顺序到达。更吸引人的解决方案是测试时间自适应,其目的是直接根据每个给出的测试样本不断地调整模型。

对于测试时自适应,我们用预训练模型fθs 通过一组参数θs和新的以不同的分布位移顺序到达在线的测试用例{x1,x2,...xt}

其中 fθs是用源训练数据 通过如下的经验风险最小化进行优化的↓其中Lm代表训练的有监督损失。 fθs 会在目标域 xt上表现不是很好因为它们之间的分布有差异。

然后,为了不断调整模型以获得更好的泛化性能,需要专门设计一个测试时间目标函数Ltt,根据每个测试样本提供的分布信息更新模型参数。对于出现在模型更新的第t次迭代中的测试样本xt,我们有↓ 其中η 代表了测试时自适应的学习率,θ1用θs初始化。使用更新后的模型fθt+1来获得测试样本xt的预测。

之前的研究已经提出了不同的测试时间目标函数Ltt,如旋转预测损失[16]、模型预测的熵最小化[17]、自编码器重建损失[21]等。

Dynamic Learning Rate on Test Data 测试数据的动态学习率 

以往的测试时间自适应方法在Eq.(2)中采用固定的学习率η,由于自适应性能对学习率敏感,需要慎重选择[17]。我们认为,所有自适应步骤的静态学习率不能准确地更新模型以克服测试数据的变化分布移位。因此,我们提出了一种动态学习策略来捕捉在测试时间适应过程中模型的变化和测试数据的不同移位程度。

1)Memory Bank Construction for Discrepancy Estimation 差异估计的记忆库构造

对于动态自适应,了解每一步的需求自适应程度是调节学习率的关键因素。为了测量适应程度,我们提出了一种基于记忆库的差异估计方法。构建记忆库存储由不断更新的模型提取的最新的特征表示 feature representation预测掩码prediction mask对。这些对反映了模型的变化,并可进一步用于计算与传入测试样本的距离,以估计分布移位程度。通过位移度测量,我们可以相应地调整测试时学习率。

具体而言,记忆模块M储存了K对key和value

如图2所示,key是被一个模型的特征提取器h计算的特征图;value是对应于分类器头部g生成的预测掩码。我们通过连续缓存新的(qk,vk)对来更新模块M。由于模型在测试时逐步更新,内存库中的早期元素不能指示测试数据上最新的模型性能。因此,我们只维护固定大小K的内存库,并在写入新对时保持先进先出(FIFO)原则。

然后测量一个新的测试样本xt的差异,我们的目标是从memory bank M 中去检索一个support set R,,包含与当前测试样本xt在语义上相似的元素的特征和预测。这是通过基于每个键与查询样本h(xt)的特征之间的L2距离计算d近邻(D-nearest neighbors)来实现的。 由于键维护语义级上下文特征,支持集R保留了关于高级信息的内部协议,例如,图像中的对象类别。R中预测的集合可作为xt的参考预测,为:

图像xt的预测误差由下式导出↓。其中Lkl代表着KL散度然后fθt(xt)代表模型fθt生成的xt的预测

2)Adaptive Learning Rate Adjustment

通过Eq.(4)测量的差异,我们的直觉是,高差异表明需要在很大程度上弥补显著的差距,低差异需要较小的适应。因此,我们建议根据估计的差异动态调整自监督学习率。为了以更一般的方式描述我们的方法,我们采用batch-wise formulation来考虑测试样品可能一个接一个或一批接一批地到达。单个测试图像对应于批大小1。

给定一批测试数据,在每个测试时间适应步骤中,我们计算总差异为↓,其中B是batch size, 捕获分布在批处理预测中的一般差异。

然后,通过一个函数得到批处理动态学习率,根据积分差直接输出测试时间适应任务的学习率↓。其中α 进一步缩放学习率,可以经验地设置为使用源训练图像进行模型优化的值。

我们提出的测试时动态学习率调整可以很容易地部署到任何网络架构和自监督目标函数中,以改善测试时的适应过程。

由于特征表示和预测掩码是自然计算并存储在内存库中,所以测试时的计算成本主要来自于参考预测p^t 计算支持集的检索。我们通过保持较小的检索大小(例如8)来提高效率。在实验中分析了检索尺寸(retrieval size)的效果,结果表明,较小的尺寸已经可以取得较好的效果,增加尺寸(例如增加到20)并不会表现出更高的性能。

Learning Process and Implementation Details 学习过程和实施细节

1) Learning Process

要更新的模型首先通过从源训练数据学习到的参数进行初始化。

对于即将到来的测试样本,该模型首先进行前向传递,获得特征表示和预测掩码对,从记忆库中检索语义相似的元素,并计算特定测试样本的差异和动态学习率。然后使用导出的学习率和测试时间目标函数对模型进行更新,以达到期望的自适应。需要注意的是,在之前具有不同测试时间适应目标的工作中,只适应了部分模型参数[16],[17],[21]。

在一步自适应之后,该模型进行另一次前向传递,以获得对当前测试样本的预测并更新内存库的新元素对。

该模型同 每个顺序到来的测试样本 与前一个过程迭代地更新 。在构建记忆库之前,学习率保持为初始值。DLTTA伪代码如算法1↓所示。

实验

数据集

视网膜OCT图像分割:OCT

组织病理学图像分类:Camelyon17 [49]–[51] dataset

前列腺MRI图像分割:T2-weighted MRI volumes: NCI-ISBI13 [53], I2CVB [55],
and PROMISE12 [54]

实现细节

backbone

OCT 数据集: U-net[40], 

ImageNet for histopathological image classification: DenseNet-121 [42]

prostate MRI image segmentation: 3D U-Net

我们的2D U-Net和3D U-Net的编码器包含4个Convolution- batchnorm - relu块,它们连续地降低了图像分辨率并将特征通道尺寸增加了一倍。

然和特征通过bottle neck层输入到解码器中,该解码器具有4个转置卷积块对中间特征映射进行上采样。

所有的二维卷积都使用内核大小3×3;3D卷积使用内核大小3×3×3

编码器特征在每个阶段skip-connection到解码器,并且每个块的特征通道尺寸为[16,32,64,128]。对于分割网络,我们将编码器和bottle neck 层作为特征提取器,将解码器作为预测头。

DenseNet-121的架构遵循Torchvision[44]库中的实现,其中包括
4个紧密连接的块(densely connected blocks),然后是全局池化global pooling和完全连接的分类层fully connected classification layers。对于每个任务,我们的方法和其他比较方法的网络骨干是相同的,以确保公平的比较。

训练细节

source

对于源图像的模型训练,模型从头开始训练100次

对于分割和分类任务,Adam优化器和学习率初始化为1e-3和3e-4。

test-time adaption

对于我们的方法和所有比较基线的测试时间适应,我们对每批测试数据执行一步适应,分割批大小为1,分类任务批大小为200

我们在K/B步里存储最新的键和值对,我们经验地设置K/B为分割20分,分类4分。

分割任务的检索大小D设置为8,分类任务的检索大小D设置为12。

参考[18],对所有方法的试验数据重新进行BN统计。

框架使用Pytorch 1.7.0实现,并在一个NVIDIA TitanXp GPU上进行了训练。

实验结果

消融实验 

这回主要看的是方法,消融实验以后再补笔记叭

Effectiveness of Learning rate adjustment

Adaption Stability

Effect of Initial Learning Rate

Influence of Retrieval Size

Influence of Image Orders

Effect of Similarity Metrics

总结

讨论

本文解决了具有挑战性的测试时自适应问题,旨在通过学习测试时提供的推理样本,将深度模型推广到未知数据分布。本文提出了一种测试时间适应的动态学习率策略,旨在根据估计的预测差异动态调整模型更新的步长。

我们还实现了TTA方法的图像特定变体。每个测试图像都经过10次迭代更新,以允许充分的适应。在表6中,我们得到了与[16]相似的观察结果,连续自适应优于图像特定版本,计算量减少了10倍。通过对在线模型更新的动态调节,使该方法更有利于连续自适应。

现有的TTA方法要么对每个测试图像进行多个自适应步骤,如ATTA,要么进行一次梯度更新,如TTT和Tent。对于每个测试图像的自适应,我们只进行了一次梯度更新,因为得益于所提出的动态学习率调整,在线自适应过程显式模块化,一步就可以实现有效的自适应。我们还尝试在每个测试图像上为多个步骤更新模型。表7给出了我们的方法在1、4和8次梯度更新时的结果。我们可以看到,多次更新不能获得一致的改进,但需要更多的计算成本。结果表明,该方法只需一次梯度更新即可实现有效的测试时间自适应,计算成本更低,但性能比以往方法有所提高

我们的方法的一个局限性是,在构建记忆库之前,很难估计差异,因此没有在前K个图像的测试时间自适应中添加动态学习率调整。在不使用动态学习率进行自适应的情况下,我们的方法对前K张图像的性能与之前取决于动态学习率部署到的目标函数的TTA方法方法相同

如何对前几张测试图像的动态学习率调整获得更好的差异估计,并在一开始就实现更快、更有效的自适应,将是未来有趣的工作。一种可能的解决方案是在内存库中存储一些训练数据特征统计,并通过比较训练和测试特征统计以及预测来调整学习率。

总结

提出了第一种测试时间自适应的动态学习率调整方法,使模型能够有效地适应测试数据间分布的变化。我们提出了一种基于记忆库的差异测量方法,同时考虑了测试时间模型的渐进变化和测试数据分布的变化,并进一步实现了基于估计差异的动态学习率调整。该方法对测试时间自适应目标具有有效性和通用性,不需要改变网络设计,可以很容易地应用于改进不同的测试时间自适应方法。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值