Adaptive Task-Aware Refining Network for Few-Shot Fine-Grained Image Classification
Abstract
- 少样本细粒度(FSFG)图像分类的主要挑战是学习少量标记样本的区别特征表示。为了应对这一挑战,任务感知的少样本学习方法被引入。然而,现有方法只关注如何将任务信息与特征表示相关联,而忽略了两个关键问题。第一个问题是如何用较少的标记样本获得准确的任务表征,从而准确地得到与任务相关的特征区域。第二个问题是如何降低在获取特征区域的过程中引入的背景噪声的影响,以缓解少镜头设置下的过拟合问题。为了解决这些问题,我们提出了自适应任务感知精炼网络(ATR-Net)。与以前使用类中心作为任务表示的方法不同,ATR-Net使模型能够通过与所有局部特征的补丁进行交互来自适应地选择特定于任务的信息,从而产生更准确的任务表示。此外,设计了通道区域感知模块(CRAM)和精细过滤模块(RFM ),以更好地获取任务级信息和实例级信息,从而克服背景噪声的影响。我们在四个公开的细粒度数据集上进行了大量的实验。实验结果表明,该方法具有较好的性能。我们的代码可以获得: Anonymized Repository - Anonymous GitHub
- 论文地址:Adaptive Task-Aware Refining Network for Few-Shot Fine-Grained Image Classification | IEEE Journals & Magazine | IEEE Xplore
- 主要目标是解决小样本细粒度图像分类中判别特征学习的问题,具体包括准确获取任务表示和减少背景噪声影响。为了实现这个目标,研究者提出了 ATR-Net,包含 TRGM、CRAM 和 RFM 等模块。需要逐个分析这些模块的设计思路和作用,比如 TRGM 如何通过局部特征交互生成任务表示,CRAM 如何利用多通道函数获取任务级信息,RFM 如何通过空间注意力过滤噪声。
- 针对小样本细粒度图像分类(Few-Shot Fine-Grained Image Classification, FSFG),解决两大核心问题:准确获取任务表示:现有方法通过全局平均池化压缩支持集特征,导致任务表示受背景噪声干扰,无法捕捉细粒度差异(如鸟类的喙部、羽毛纹理等局部特征)。减少背景噪声影响:细粒度图像类间差异小、类内差异大,提取任务相关特征时易引入背景噪声,导致模型过拟合。
- 任务感知机制:不同任务(如区分 “信天翁” 与 “海雀”)的关键特征不同(如翅膀形状 vs. 喙部弯曲度),模型需自适应选择与当前任务相关的局部特征。细粒度任务的关键特征随任务变化(如 “鸟 vs. 狗” 关注翅膀,“信天翁 vs. 海雀” 关注喙部),需通过局部特征交互动态生成任务表示,而非固定类中心。
- 分层特征过滤:先通过任务级模块定位粗粒度相关区域(如 “鸟类身体”),再通过实例级模块筛选细粒度判别区域(如 “喙部细节”),逐步过滤噪声。先通过通道级注意力定位任务相关区域(CRAM),再通过空间注意力提取实例级判别令牌(RFM),两层过滤有效减少噪声。
- ATR-Net 通过 “动态任务表示生成→任务级区域筛选→实例级特征精炼” 的三层架构,系统性解决了小样本细粒度分类中的核心挑战。其核心价值在于将人类细粒度识别的 “先定位关键区域,再聚焦细微差异” 的逻辑转化为可计算的模块组合,为后续研究提供了任务自适应与分层特征过滤的方法论。
INTRODUCTION
-
细粒度图像分类 旨在区分同一超级类别中的各种子类别(例如,各种鸟类)。如图1所示,由于类别之间的细微差异和类别内的大差异的特征,这些子类别的识别比一般的图像识别更困难。大多数现有的细粒度方法严重依赖于大型数据集的可用性。然而,在现实场景中,注释大量细粒度的子类别是非常昂贵的,也是不实际的。例如,在医学诊断中,准确识别疾病的各种亚型通常需要医学专业人员的专业知识。许多其他领域也面临子样本稀缺的问题,包括濒危物种保护领域。因此,许多研究人员将研究重点转移到少样本学习 ,探索如何用少量标记样本解决细粒度图像识别问题。
-
将传统的少样本学习方法应用于细粒度场景是一种直观的方法。其中,使用元学习机制和度量学习框架来辨别查询样本和有限数量的支持样本之间的相似性是有效的。具体来说,这些方法通过固定的嵌入函数将图像映射到嵌入空间。然后,通过距离函数(例如,欧几里德距离或余弦距离)计算查询图像和支持图像之间的相似度。然而,这种固定的嵌入功能在任务之间共享一组公共区域,这可能无法解释类别之间的差异。例如,如图1 (a)和(b)所示,在“狗”的上下文中识别“鸟”的方式与在“鸡”的上下文中识别“鸟”的方式显著不同。
-
-
图一。(a)和(b)是两个粗粒度的图像任务。©和(d)是两个细粒度的图像任务。与粗粒度图像任务相比,细粒度图像之间的可变性更小,更难识别。
-
-
在前一种情况下,机翼的存在是一个至关重要的区别特征,而在后一种情况下,它是无关紧要的。换句话说,语义特征的重要性随着任务而变化,并且模型应该具有任务感知能力。
-
为了使模型获得任务感知能力,研究人员探索了适应性嵌入的概念 。这些方法主要集中于增加每个任务中的特征信息,然而它们仍然严重依赖于固定的嵌入模型,并且受到先前学习的任务的限制。针对这一限制,已经提出了完全自适应的方法 。这些方法压缩来自任务(即,一集)中的所有支持图像的信息,以构建特定于任务的表示。该表示随后被用作核函数来过滤支持和查询样本。这种方法显著提高了模型识别和关注与手头任务最相关的特征区域的能力。
-
然而,上述方法只关注如何将任务信息与特征表示相关联,而忽略了两个关键问题。第一个是不准确的特定任务表征。如图1 ©和(d)所示,细粒度图像的特征在于低类间变化和高类内变化。这个特性模糊了细粒度图像中特定任务之间的区别,从而降低了它们的可识别性。传统方法通过简单地压缩(即,全局平均池化)任务中的所有支持图像来导出当前任务信息,并将类嵌入的中心视为任务特定的表示,如图2 (a)所示。然而,这种方法非常容易受到背景噪声的影响,并且不能准确地捕捉每个任务的关键信息。这个问题在细粒度图像中尤其明显,因为它们的类间可变性很小。因此,我们首次采用一种可学习的方法来获得更精确的任务表示,并将它们与图像特征的每个通道相结合,以获得更精确的任务特征区域。据我们所知,以前没有做过类似的工作。
-
-
图二。模型动机图。(a)和(b)分别描述了通过传统方法和我们提出的方法获得任务特定表征的方式。传统方法压缩当前任务的所有支持样本,并将所有类的中心作为特定于任务的表示。我们的方法使网络能够通过支持集的局部特征之间的相互作用自适应地选择与当前任务相关的语义信息。
-
-
第二个是在获取与当前任务相关的特征区域的过程中引入的背景噪声,这在少样本设置中容易引起模型过拟合问题。如图1 ©和(d)所示,当识别不同的鸟类时,我们不仅需要过滤掉不相关的对象(任务级,例如岩石),还需要对鸟类特有的细微区别特征敏感(实例级,例如喙)。由于低的类间可变性,当获取任务相关的特征区域时,可能引入额外的背景噪声。在以前的工作中,特征重建或特征对齐被用来缓解这个问题。这些方法虽然有效,但算法复杂,模型训练过程缓慢。因此,设计一种简单有效的方法是很有必要的。
-
针对上述挑战,本文提出了一种新的方法,称为自适应任务感知精炼表示网络(ATR-Net)。其架构如图3所示。为了更准确地提取特定任务的表示,我们提出了任务表示生成模块(TRGM),如图2 (b)所示。具体来说,我们将一个任务中的所有支持样本划分为多个局部特征补丁。随后,我们采用网络g(θ)来促进这些贴片之间的相互作用。通过利用注意机制和梯度优化,g(θ)自适应地“选择”代表当前任务的不同范围的语义信息。最后,我们综合所有的语义信息来确定最终的任务表示。
-
-
图3。在三路单触发设置下的ATR-Net框架。它分别由特征嵌入模块、任务感知信道过滤模块和细化过滤模块组成。
-
输入:支持集(Support Set)和查询集(Query Set)图像,通过嵌入模块(如 Conv-4/ResNet-12)提取特征。核心模块:三层级联结构,逐步细化特征表示:任务表示生成模块(TRGM):动态生成任务特定表示,解决 “准确任务表示” 问题。通道区域感知模块(CRAM):筛选与任务相关的通道和空间区域,获取任务级信息。精炼过滤模块(RFM):通过空间注意力提取实例级判别特征,过滤背景噪声。
-
import torch import torch.nn as nn import torch.nn.functional as F from models.backbone_res12 import ResNet from models.conv4 import ConvNet from models.TACF import TACF import numpy as np class ATR_Net(nn.Module): def __init__(self, args, resnet=False, mode=None): super().__init__() # 模型的运行模式,如 'fc', 'encoder', 'ATR_Net' self.mode = mode # 命令行参数,包含数据集、训练参数等信息 self.args = args # 是否使用 ResNet 作为编码器 self.resnet = resnet # 可学习的缩放因子,用于调整相似度矩阵的尺度 self.scale = nn.Parameter(torch.FloatTensor([1.0]), requires_grad=True) # 令牌数量 self.num_token = args.num_token if resnet: # ResNet 编码器的输出维度 self.encoder_dim = 640 # 使用 ResNet 作为编码器 self.encoder = ResNet() print("This is ResNet") else: # ConvNet 编码器的输出维度 self.encoder_dim = 64 # 使用 4 层的 ConvNet 作为编码器 self.encoder = ConvNet(4) print("This is ConvNet") print("descriptor:", args.num_descriptor, "num_token:", self.num_token) # 初始化 TACF 模块,用于任务自适应的上下文融合 self.tacf = TACF(args, self.encoder_dim, 2, 1, args.num_descriptor, self.num_token) # 全连接层,用于分类 self.fc = nn.Linear(self.encoder_dim, self.args.num_class) def forward(self, input): if self.mode == 'fc': # 全连接层模式,对输入特征进行分类 return self.fc_forward(input) elif self.mode == 'encoder': # 编码器模式,提取输入的特征 x = self.encoder(input) return x elif self.mode == 'ATR_Net': # ATR_Net 模式,进行任务自适应的特征学习和分类 spt, qry = input # 通过 TACF 模块处理支持集和查询集的特征 ch_task_spt, ch_task_qry = self.tacf(spt, qry) return self.metric(ch_task_spt, ch_task_qry) else: # 未知模式,抛出异常 raise ValueError('Unknown mode') def fc_forward(self, x): # 对输入特征在最后两个维度上求平均 x = x.mean(dim=[-1, -2]) # 通过全连接层进行分类 return self.fc(x) def metric(self, token_support, token_query): # 对查询集特征在最后一个维度上求平均 qry_pooled = token_query.mean(dim=[-1]) # 对支持集和查询集的特征进行归一化 token_spt = self.normalize_feature(token_support) token_qry = self.normalize_feature(token_query) # 获取支持集的类别数 way = token_spt.shape[0] # 获取查询集的样本数 num_qry = token_qry.shape[0] # 扩展支持集和查询集的维度,以便进行相似度计算 token_spt = token_spt.unsqueeze(0).repeat(num_qry, 1, 1, 1) token_qry = token_qry.unsqueeze(1).repeat(1, way, 1, 1) # 对扩展后的支持集和查询集特征在最后一个维度上求平均 spt_attended_pooled = token_spt.mean(dim=[-1]) qry_attended_pooled = token_qry.mean(dim=[-1]) # 计算支持集和查询集特征的余弦相似度矩阵 similarity_matrix = F.cosine_similarity(spt_attended_pooled, qry_attended_pooled, dim=-1) # 对相似度矩阵进行缩放 logits = similarity_matrix * self.scale if self.training: # 训练模式下,返回相似度矩阵和全连接层的输出 return logits, self.fc(qry_pooled) else: # 测试模式下,只返回相似度矩阵 return logits def normalize_feature(self, x): # 对输入特征进行归一化,减去每个特征的均值 return x - x.mean(1).unsqueeze(1)
-
假设输入图像的维度为
(batch_size, 3, 84, 84)
。如果使用ResNet
作为编码器,输出维度为(batch_size, 640, H', W')
,其中H'
和W'
是经过多次卷积和下采样后的高度和宽度。如果使用ConvNet(4)
作为编码器,输出维度为(batch_size, 64, H', W')
。 -
TACF 模块阶段:支持集特征
spt
和查询集特征qry
,维度分别为(way * shot, encoder_dim, H', W')
和(way * query, encoder_dim, H', W')
。经过 TACF 模块处理后,输出的支持集和查询集特征ch_task_spt
和ch_task_qry
的维度保持不变。 -
度量计算阶段:支持集特征
token_support
和查询集特征token_query
,维度分别为(way * shot, encoder_dim, H', W')
和(way * query, encoder_dim, H', W')
。qry_pooled
的维度为(way * query, encoder_dim)
。token_spt
扩展维度后为(way * query, way * shot, encoder_dim, H', W')
。token_qry
扩展维度后为(way * query, way * shot, encoder_dim, H', W')
。spt_attended_pooled
和qry_attended_pooled
的维度均为(way * query, way * shot, encoder_dim)
。similarity_matrix
的维度为(way * query, way * shot)
。 -
ATR_Net
模型通过编码器提取输入图像的特征,然后使用 TACF 模块进行任务自适应的上下文融合,最后通过度量计算和全连接层进行分类。在数据流转过程中,不同阶段的维度变化是为了适应模型的结构和计算需求。
-
-
为了最小化背景噪声的干扰,我们分别提出了信道区域感知模块(CRAM)和精细滤波模块(RFM)。CRAM通过使用多通道函数来提取与当前任务最相关的语义信息,从而获得更精确的任务感知特征区域。之后,RFM在任务感知特征区域内进一步定位重要的判别区域。它通过使用具有注意机制的多个记号化器函数来实现这一点,这些记号化器函数采用空间注意机制来为每个输入计算多个空间权重图,并且这些图可以学习特征中的重要区分区域。因此,我们可以通过仅使用几个关键的视觉表征来准确地获得图像的区别性表示。在四个细粒度数据集上的大量实验表明,我们提出的方法达到了最先进的性能。总之,我们的贡献有四个方面:
-
我们提出了一个创新的网络,称为自适应任务感知精炼表示网络(ATR-Net)。该方法首先根据不同的任务区分任务相关的特征区域,然后在这些区域中识别最具区分性的特征,以准确地确定当前图像的类别。这更符合人类的行为习惯。
-
我们提出了任务生成模型(TRGM)和通道区域感知模型(CRAM),首次将重点放在准确获取任务表征并将这些任务与图像特征的每个通道相关联,而不仅仅是获取与任务相关联的空间区域。
-
我们提出了精细滤波模块(RFM ),该模块自适应地滤除背景中的噪声信息,并且仅使用几个关键的视觉表征就可以表示图像特征中的重要特征信息。这种方法比以前复杂的特征重建和特征对准方法更简单和更有效
-
我们在四个流行的细粒度数据集上进行了全面的实验。实验结果表明,ATR-Net在四个广泛使用的基准数据集上取得了最好的性能。
-
-
任务表示生成模块(TRGM):传统方法使用类中心(如全局平均池化)作为任务表示,忽略局部特征交互。TRGM 通过视觉 Transformer 机制,让任务令牌(Task Tokens)与支持集的局部特征块交互,自适应选择关键语义信息。
- 将支持集特征拆分为局部补丁(Local Patches),与可学习的任务令牌拼接后输入 Transformer。通过多头自注意力机制,任务令牌从局部补丁中 “选择” 最能代表当前任务的特征(如区分 “狗品种” 时聚焦耳朵形状补丁)。
- 作用:生成多维度任务表示,捕获细粒度任务特异性(如不同子类的关键局部特征组合)。
-
通道区域感知模块(CRAM):不同通道可能对应不同语义(如颜色通道 vs. 纹理通道),通过通道级注意力筛选与任务相关的区域。
- 对每个通道,计算任务表示与通道特征的点积,生成通道权重图(Sigmoid 激活)。加权聚合通道特征,保留任务相关区域(如 “鸟类分类” 中激活羽毛纹理通道,抑制背景颜色通道)。
- 作用:从全局特征中定位粗粒度任务相关区域,减少背景干扰(如剔除岩石、树枝等无关区域)。
-
精炼过滤模块(RFM):细粒度差异常存在于局部细微结构(如 “汽车型号” 的车灯轮廓),需进一步筛选实例级判别区域。
- 初始化多个令牌化函数(Tokenizers,如卷积层 + Sigmoid),生成空间注意力图,每个注意力图对应一个关键区域(如 “鸟喙”“汽车轮毂”)。通过全局平均池化压缩空间维度,仅保留少数关键视觉令牌(Visual Tokens)作为最终特征。
- 作用:从任务相关区域中提取最具判别力的局部特征,减少冗余信息(如仅用 3-5 个令牌代表关键细节)。
RELATED WORK
A. Few-Shot Image Classification
- 现有的少样本学习方法可以大致分为三大类:基于优化的方法、基于增强的方法和基于度量的方法。基于优化的学习方法旨在学习一个好的优化器来更新模型参数。例如,MAML 及其许多变体旨在学习良好的模型初始化,以便学习者可以快速适应新的任务。基于增强的方法专注于从基类学习生成器,并利用其他技术来生成额外的样本。这种方法通过提供补充数据,有效地解决了少击学习中遇到的数据稀缺问题。基于度量的方法是少样本学习中的主导方法。这些方法通过各种度量(例如欧几里德距离)。在早期阶段,大量的努力致力于利用固定的度量或可训练的模块来建立有效的度量函数,例如ProtoNet Network ,Relation Network 。
- 近年来,基于GNN的相似性度量也开始被提出,其优点是模型可以学习丰富的关系结构。现有的少样本学习方法通常强调学习全局特征表示,由于捕捉微妙的区别特征的挑战,这可能不太适合于细粒度图像。相比之下,我们提出的方法更强调从图像的局部特征中提取信息,使其对细粒度图像更相关和有效。
B. Task-Aware Few-Shot Image Classification
- 近年来,已经提出了许多方法来解决少样本图像中任务意识的挑战。参考文献[Finding taskrelevant features for few-shot learning by category traversal]提出了一种类别遍历模块(CTM ),能够生成特定于任务的特征掩码,以选择最相关的特征维度。参考文献[Learning instance and task-aware dynamic kernels for few-shot learning]使用支持集的实例级特征和全局上下文特征来产生动态实例核和动态任务特定核,然后构建动态网络。有的提出了一种获取特定任务参数的方法,将这一概念与局部特征相结合。他们提出的任务感知零件挖掘网络(TPMN)专门为提取基于零件的特征时使用的过滤器生成参数。提出了TAFE网络,其具有任务感知元学习器,为标准预测网络中的特征嵌入层生成权重。然而,上述工作主要集中在粗粒度图像上,并且他们的方法涉及压缩当前训练样本的所有特征以获得任务表示。不幸的是,这种压缩过程使得难以准确提取每个任务的表示,这进一步影响了模型的泛化性能。通过综合当前任务的所有局部特征信息并动态学习它们,我们可以更准确地获得特定于任务的表征。
C. Few-Shot Fine-Grained Image Classification
-
自从[Piecewise classifier mappings: Learning fine-grained learners for novel categories with few examples]将FSFG图像分类开创性地引入少样本学习领域以来,计算机视觉研究界已经提出了许多创新方法来解决这个具有挑战性的问题。FSFG图像分类的早期研究主要集中在使用各种度量学习方法来增强输入图像对的关系匹配。例子包括TAON 和BSNet 模型。随着该领域的发展,后来的研究探索了注意力机制,以从细粒度图像中提取更具区分性的特征。参考文献利用 Transformer 模型来创建集合到集合的函数,有助于对图像集合之间的交互进行建模,以实现每图像的协同适应。有的参考文献介绍了双注意力网络,采用注意力机制来选择性地保留有用的深度描述符,将它们集成到多实例学习框架中,以利用细粒度图像部分之间的相关性。
-
最近,人们越来越关注FSFG图像中目标的空间对齐。这种转变导致了创新的方法,如特征重构网络(FRN) ,该网络使用最优重构权重的封闭形式的解决方案从支持样本特征重构查询样本特征。然而,吴等人指出了单向重建方法的局限性,特别是它不能适应图像的类内变化。为了克服这一点,他们提出了一个双向重建策略,称为BiFRN ,旨在通过合并支持到查询和查询到支持的重建来改善对齐。此外,李等认为当简单地使用通常嵌入在模块中的基本特征时,对于细粒度图像分类至关重要的区别性局部特征在重建中没有被很好地考虑,因此提出了LCCRN。
-
在现有的方法中,明显缺乏专门为FSFG设计的任务感知方法。一般的任务感知技术往往不能满足FSFG的独特需求。我们的方法旨在通过针对FSFG分类的特定需求提供专门的解决方案来解决这一缺陷。
METHOD
- 在这一节中,我们首先定义FSFG学习的任务。随后,我们描述了我们提出的方法,随后是对每个组件的概述和深入描述。
A. Problem Definition
-
对于FSFG图像分类任务,有三组具有不相交标签的数据可用,即基集Dbase、验证集Dval 和 新集Dnovel,以及{Dbase∩Dval∤Dnovel=∅}。在训练阶段,从Dbase中随机选择一组任务 T i i = 1 I {T_i}^I _{i=1} Tii=1I 。每个任务都包含一个支持集 X S 和一个查询集 X Q。X S从Dbase中随机选择N个类,每个类包含M个样本。X S中总共有N×M个样本,称为N路M-shot。同样,我们从相同的N个类中选择Q个样本来形成查询数据集 X Q X ^Q XQ,在 X Q X ^Q XQ 中总共有N×Q个样本, X S ∩ X Q = ∅ X ^S∩X^ Q=∅ XS∩XQ=∅ 。FSFG任务的目标是通过 X S X ^S XS 估计 X Q X ^Q XQ 中样本的类别。
-
许多现有的 FSFG 方法采用元学习范式 。具体来说,该模型是通过从集合 { T i } i = 1 I \{T_i\}^I _{i=1} {Ti}i=1I 中随机选择一个任务来训练的,我们将每个随机选择的任务称为一个事件。在参数选择和测试阶段,任务分别从 Dval 和 Dnovel 以相同的方式形成。
B. Overview
-
如图3所示,我们的网络由五个不同的模块组成。第一个是用于提取图像特征的嵌入模块 fθ。这可以是卷积网络或残差网络。第二个模块是任务表示生成模块(TRGM)gφ,它采用空间注意力机制来促进支持图像的局部特征之间的交互。该模块可以自适应地“选择”有效表示当前任务的相关语义信息,从而生成多个准确的任务表示。第三个是通道区域感知模块(CRAM)cγ,它获取图像的任务级信息。该模块通过多通道函数找到与当前任务最相关的特征区域。第四个是精炼滤波器模块(RFM)rτ,它获取图像实例级信息。该模块采用空间注意力机制为每个输入计算多个空间权重图,这些图可以学习特征中重要的判别区域。最后,第五个模块是分类器模块,它通过距离函数计算最终获取的支持和查询样本特征之间的相似性。
-
class ATR_Net(nn.Module): def __init__(self, args, resnet=False): super().__init__() self.args = args # 选择编码器(ConvNet或ResNet) self.encoder = ResNet() if resnet else ConvNet(4) self.tacf = TACF(args, self.encoder_dim, num_descriptor=args.num_descriptor) # TACF模块 self.fc = nn.Linear(self.encoder_dim, args.num_class) # 分类头 def forward(self, input): if self.mode == 'encoder': return self.encoder(input) # 特征提取模式 elif self.mode == 'ATR_Net': spt, qry = input # 支持集与查询集特征 # 生成任务特定区域特征 ch_task_spt, ch_task_qry = self.tacf(spt, qry) return self.metric(ch_task_spt, ch_task_qry) # 度量计算 elif self.mode == 'fc': return self.fc_forward(input) # 分类模式(辅助损失) def metric(self, token_support, token_query): # 归一化与余弦相似度计算 token_spt = self.normalize_feature(token_support) token_qry = self.normalize_feature(token_query) # 广播机制计算所有查询-支持对的相似度 similarity_matrix = F.cosine_similarity( token_spt.unsqueeze(0), token_qry.unsqueeze(1), dim=-1 ) * self.scale # 缩放因子 return similarity_matrix
-
通过
mode
参数切换功能(特征提取、任务处理、分类),训练时结合多任务损失优化。利用余弦相似度计算类别相似性,scale
参数可学习,放大差异。 -
数据集加载(dataset_builder) → 任务采样(CategoriesSampler) → 数据输入编码器(ConvNet/ResNet) → 特征提取(B, C, H, W) → TACF模块处理(生成任务特定区域特征) → 度量计算(余弦相似度) → 损失计算(epi_loss + 辅助损失) → 反向传播优化
-
关键参数:
way
:类别数(N-way,如 5);shot
:每类支持样本数(K-shot,如 1/5);query
:每类查询样本数(如 15) -
# 配置参数(修改args中的dataset、way、shot等) python task_token_train.py --dataset cub --way 5 --shot 1 --resnet False
-
在
TACF.py
的task_specific_region_selector
中添加注意力图可视化代码,观察任务特定区域的激活情况。监控epi_loss
(小样本任务损失)和loss_aux
(辅助分类损失)的平衡,调整lamb
参数。
-
-
ATR-Net 通过任务自适应的注意力机制和度量学习,有效解决了小样本学习中特征泛化不足的问题。核心代码集中在
ATR_Net.py
(整体架构)、TACF.py
(注意力模块)和conv4.py
/backbone_res12.py
(编码器)。数据流转通过自定义采样器确保小样本任务的生成,训练时结合多损失函数优化,提升模型在少样本场景下的分类能力。
C. Task Representation Generation Module
-
我们首先定义特征图X。对于N路M样本分类任务,我们 将 N×(M+Q)个样本输入到嵌入模型fθ中以提取它们的特征,支持集和查询集中的特征分别表示为 X S 、 X Q X ^S、X ^Q XS、XQ。该过程正式描述如下:
-
X i , j S = f θ ( x i , j S ) ∈ R h × w × c , X Q = f θ ( x Q ) ∈ R h × w × c , ( 1 ) X ^S _{i,j} = fθ (x ^S _{i,j})∈ \R ^{h×w×c} ,\\ X^ Q = fθ (x^ Q ) ∈ \R ^{h×w×c} , (1) Xi,jS=fθ(xi,jS)∈Rh×w×c,XQ=fθ(xQ)∈Rh×w×c,(1)
-
其中 x i , j S x^ S _{i,j} xi,jS 表示支持中第 i 个类的第 j 个实例,x Q是查询实例,h,w,c分别表示高度、宽度和通道的数量。然后,我们将X S的特征输入到任务表示生成模块(TRGM)中,如图4所示。
-
-
图4。任务表示生成模块(TRGM)。
-
-
第一个挑战是如何准确地获得每个特定任务的任务表示。在之前的研究中,流行的方法通过压缩任务(即一集)中的所有支持图像来表示当前任务。然而,通过这种方法获得的任务信息缺乏准确性。为了获取当前任务的更准确的任务信息,我们建议基于当前的局部特征信息学习任务表示。
-
为了尽量减少类间相似性的负面影响,我们认为从更微妙的特征中提取任务信息是有益的。因此,我们首先捕获当前任务(一集)中所有支持样本的局部特征信息,如方程式2所示。
-
F S = R ( X S ) ∈ R ( N ⋅ M ⋅ h ⋅ w ) × c F ^S = R(X ^S ) ∈ \R ^{(N·M·h·w)×c} FS=R(XS)∈R(N⋅M⋅h⋅w)×c
-
其中R是重塑操作。之后,为了从大量的局部特征信息中准确地获得当前任务的表示,网络采用了动态学习方法。具体来说,我们引入了多个任务令牌 T ∈ R k × c T∈\R ^{k×c} T∈Rk×c,其中k是任务令牌的数量。任务令牌T和所有支持令牌 F S F ^S FS 被连接成一个序列,并被输入到传统的视觉TRANSFORMER中。通过采用Transformer的注意力机制,k个初始化的任务令牌通过与支持样本的局部特征信息的连续交互,自适应地识别最能代表当前任务的特征信息。如方程式3所示:
-
Z ^ = S o f t m a x ( ( ( F S ∣ ∣ T ) W ϕ Q ) ( ( F S ∣ ∣ T ) W ϕ K ) T d k ) ( F S ∣ ∣ T ) W ϕ V , \hat{Z}=\mathrm{S o f t m a x} \left( \frac{( ( F^{S} | | T ) W_{\phi}^{Q} ) ( ( F^{S} | | T ) W_{\phi}^{K} )^{T}} {\sqrt{d_{k}}} \right) ( F^{S} | | T ) W_{\phi}^{V}, Z^=Softmax(dk((FS∣∣T)WϕQ)((FS∣∣T)WϕK)T)(FS∣∣T)WϕV,
-
其中||是连接操作。 W φ Q 、 W φ K 和 W φ V W ^Q_φ、W ^K_φ 和 W ^V_φ WφQ、WφK和WφV是一组c×c大小的可学习权重参数。
-
-
然后,我们通过层归一化(LN)和多层感知器(MLP)计算获得的任务表示,以使获得的任务表达更加准确。如方程式4所示:
-
z ^ = M L P ( L N ( z ^ ) ) ∈ R r ∗ c \hat z=MLP(LN(\hat z))\in\R^{r*c} z^=MLP(LN(z^))∈Rr∗c
-
其中r=k+N·M·h·w。最后,我们从Zˆ中选择前k个输出,得到k个任务表示,记为 T ^ ∈ R k × c \hat T∈\R ^{k×c} T^∈Rk×c。
-
-
与之前通过全局平均池获得单个任务表示的方法相比,TRGM允许网络通过任务内所有局部特征的交互,从多个来源自适应地获取表示当前任务的特征信息,从而获得更准确的任务表示。此外,虽然TRGM生成了一些冗余计算,但由于初始化的k数量较少,其数量可以忽略不计。因此,没有进行额外的数据处理。
D. Channel Region-Aware Module
-
在获得当前任务的表示后,通过任务表示准确捕获图像的图像任务级信息变得至关重要。换句话说,网络应该能够准确地捕获与特定任务相关的特征区域。最近,已经证明通过通道权重有效地定位目标的歧视性细节对FSFG图像任务是有益的。受此启发,我们将每个特征通道与任务表示相关联。我们将每个特征通道视为一个单独的函数,并通过每个通道函数映射任务表示。此过程强调每个通道上与当前任务相关的特征区域。因此,当获取的任务表示遍历不同的通道函数时,这些函数自然会优先考虑与当前任务最相关的图像区域。该过程如图5所示。
-
-
图5。信道区域感知模块(CRAM)。为了更直观,图中只使用了一个查询特征图。
-
-
我们首先采用过滤过程,将多个任务表示与特征图 X 的每个通道进行比较,旨在识别每个通道特征图中与当前任务最相关的区域。我们利用信道平均(CA)操作和 Sigmoid 函数来组合来自所有信道特征图的选定区域的信息,从而获得与当前任务最相关的区域的多个权重图。如方程式5所示:
-
W i , j S = σ ( 1 c ∑ v = 1 c ( X i , j S ⊙ T ^ ) v ) ∈ R k × h × w , W Q = σ ( 1 c ∑ v = 1 c ( X Q ⊙ T ^ ) v ) ∈ R k × h × w . (5) \begin{aligned} {{{}}} & {{} {{} {{} W_{i, j}^{S}=\sigma( {\frac{1} {c}} \sum_{v=1}^{c} ( X_{i, j}^{S} \odot{\hat{T}} )_{v} ) \in\mathbb{R}^{k \times h \times w},}}} \\ {{{}}} & {{} {{{} {} {{} W^{Q}=\sigma( {\frac{1} {c}} \sum_{v=1}^{c} ( X^{Q} \odot{\hat{T}} )_{v} ) \in\mathbb{R}^{k \times h \times w}.}}}} \\ \end{aligned} \tag{5} Wi,jS=σ(c1v=1∑c(Xi,jS⊙T^)v)∈Rk×h×w,WQ=σ(c1v=1∑c(XQ⊙T^)v)∈Rk×h×w.(5)
-
其中⊙表示哈达玛积,σ表示Sigmoid函数,v表示特征的第v个通道。请注意,在计算之前,需要将两个输入重新整形为相同的维度,其他计算类似。
-
-
我们利用权重图作为特征图的掩码,使模型能够为支持图像和查询图像导出多个任务感知特征。最后,我们利用任务平均(TA)操作,该操作整合了所有任务感知特征的信息,以获得全面的表示。如方程式6所示:
-
P i , j S = 1 k ∑ t = 1 k ( W i , j S ⊙ X i , j S ) t ∈ R h × w × c , P Q = 1 k ∑ t = 1 k ( W Q ⊙ X Q ) t ∈ R h × w × c , (6) P_{i, j}^{S}=\frac{1} {k} \sum_{t=1}^{k} ( W_{i, j}^{S} \odot X_{i, j}^{S} )_{t} \in\mathbb{R}^{h \times w \times c}, \, P^{Q}=\frac{1} {k} \sum_{t=1}^{k} ( W^{Q} \odot X^{Q} )_{t} \in\mathbb{R}^{h \times w \times c}, \tag{6} Pi,jS=k1t=1∑k(Wi,jS⊙Xi,jS)t∈Rh×w×c,PQ=k1t=1∑k(WQ⊙XQ)t∈Rh×w×c,(6)
-
其中t表示特征的第t个任务表示。
-
E. Refining Filter Module
-
由于细粒度图像的固有特性,表现出较低的类间差异和较大的类内差异,因此认为所获得的任务感知特征本身是细粒度图像最具辨别力的特征是不合适的。因此,我们对获取的任务感知特征进行了更精细的特征提取,以获得更具辨别力的特征表示。如图6所示。
-
-
图6。精炼过滤器模块(RFM)。
-
-
为了减少背景噪声的影响并从任务感知特征中提取细微的区别特征,我们首先初始化n个可学习的记号化器函数 [ h o ( ⋅ ) ] o n = 1 [ho(·)]^ n_ o=1 [ho(⋅)]on=1。在本文中,记号赋予器由4个卷积层和一个sigmoid函数实现。理论上,也可以使用其他网络。然后,我们使用这些函数来构建一个空间注意机制,它能够在任务感知特征中自适应地选择有区别的区域。最后,为了进一步降低模型的计算复杂度,我们应用全局平均池来压缩特征的空间维度。因此,任务感知特征的区分区域由几个重要的视觉表征来表示。具体而言,对于输入的任务感知特征,该模型首先通过记号化器函数ho生成空间权重图 R h × w × 1 R ^{h×w×1} Rh×w×1,每个空间权重图表示从任务感知特征激活的重要区域,计算如下:
-
( W ‾ i , j S ) o = σ ( h o ( P i , j S ) ) ∈ R h × w × 1 . (7) ( \overline{{{W}}}_{i, j}^{S} )_{o}=\sigma\left( h_{o} \left( P_{i, j}^{S} \right) \right) \in\mathbb{R}^{h \times w \times1}. \tag{7} (Wi,jS)o=σ(ho(Pi,jS))∈Rh×w×1.(7)
-
-
然后,我们利用获得的权重图作为任务感知特征本身的掩模,以强调任务感知特征中的关键区域。最后,使用全局平均池来减少特征的维数,如下所述:
-
( P ^ i , j S ) o = G A P ( P i , j S ⊙ ( W ‾ i , j S ) o ) ∈ R 1 × 1 × c , (8) ( \hat{P}_{i, j}^{S} )_{o}=G A P ( P_{i, j}^{S} \odot( \overline{{{{W}}}}_{i, j}^{S} )_{o} ) \in R^{1 \times1 \times c}, \tag{8} (P^i,jS)o=GAP(Pi,jS⊙(Wi,jS)o)∈R1×1×c,(8)
-
其中GAP表示空间全局平均汇集操作。n个记号赋予器函数的结果被聚合以形成输出张量: P ^ i , j S ∈ R n × c \hat P ^S _{i,j} ∈ \R ^{n×c} P^i,jS∈Rn×c。通过相同的操作,我们可以获得查询集P Q上的特征。
-
-
RFM的本质是元素级的空间注意。CRAM旨在强化与当前任务相关的特征区域。RFM通过元素级空间注意机制进一步增强这些特征区域中的重要表征,并提取这些表征。因此,我们获得的表征不再是简单信息的固定划分,而是适应特定任务的空间选择。这使得所提出的方法比以前获得区别特征的方法更简单和更有效。
F. Objective Function
-
我们的模型是一个端到端的方法,有两个损失函数。具体来说,我们在训练过程中引入了一个全连接层作为分类器。随后,我们利用这个分类器来预测整个基本数据集的相应类别。我们使用交叉熵损失函数来实现这一点,如下所示:
-
L a = arg min θ , ω ∑ i = 1 D b a s e L C E ( C ω ( f θ ( x i ) ) , y i ) , (9) {\cal L}_{a}=\underset{\theta, \omega} {\operatorname{a r g \, m i n}} \sum_{i=1}^{D_{b a s e}} {\cal L}^{\mathrm{C E}} \left( C_{\omega} \left( f_{\theta} \left( x_{i} \right) \right), y_{i} \right), \tag{9} La=θ,ωargmini=1∑DbaseLCE(Cω(fθ(xi)),yi),(9)
-
其中 i 表示Dbase的第 i 个样本,Cω表示全连接层。我们用它作为辅助损失函数来指导特征提取层。此外,在通过余弦相似度计算 p ^ S \hat p^S p^S 和 p ^ Q \hat p^Q p^Q 之间的相似度之后,我们使用一个度量损失来引导模型将一个查询嵌入映射到一个同类的支持嵌入附近。对此的描述如下:
-
L m = − log exp ( cos ( ( P ^ S ) ( j ) , ( P ^ Q ) ( j ) ) ) ∑ j ′ = 1 N exp ( cos ( ( P ^ S ) ( j ′ ) , ( P ^ Q ) ( j ′ ) ) ) , (10) \mathcal{L}_{\mathrm{m}}=-\operatorname{l o g} \frac{\operatorname{e x p} \left( \operatorname{c o s} \left( ( \hat{\mathbf{P}}^{S} )^{( j )}, ( \hat{\mathbf{P}}^{Q} )^{( j )} \right) \right)} {\sum_{j^{\prime}=1}^{N} \operatorname{e x p} \left( \operatorname{c o s} \left( ( \hat{\mathbf{P}}^{S} )^{( j^{\prime} )}, ( \hat{\mathbf{P}}^{Q} )^{( j^{\prime} )} \right) \right)}, \tag{10} Lm=−log∑j′=1Nexp(cos((P^S)(j′),(P^Q)(j′)))exp(cos((P^S)(j),(P^Q)(j))),(10)
-
其中cos(·)表示余弦相似度。总损失是上述两个子损失的总和:
-
L = L a + λ L m . ( 11 ) L = L_a + λL_m. (11) L=La+λLm.(11)
-
其中λ是平衡两个损失项的超参数。在测试阶段,不再使用分类器,只计算支持样本和查询样本之间的相似度。
-
-
损失设计方面,使用了交叉熵损失和度量损失的组合,前者辅助特征提取,后者引导样本相似性学习。需要分析这种损失函数设计如何适配小样本细粒度分类任务,平衡特征学习和度量学习的作用。
-
交叉熵损失(辅助损失):对基类数据(Base Set)进行分类训练,引导嵌入模块学习通用视觉特征(如边缘、纹理)。公式: L a = min ∑ L C E ( C ω ( f θ ( x i ) ) , y i ) \mathcal{L}_a = \min \sum \mathcal{L}^{CE}(C_\omega(f_\theta(x_i)), y_i) La=min∑LCE(Cω(fθ(xi)),yi),其中 C ω C_\omega Cω 为全连接层, f θ f_\theta fθ 为嵌入模块。
-
度量损失(主损失):在小样本任务中,计算支持集与查询集特征的余弦相似度,迫使同类样本特征接近,异类远离。公式: L m = − log exp ( cos ( P ^ S , P ^ Q ) ) ∑ exp ( cos ( P ^ S , P ^ Q ) ) \mathcal{L}_m = -\log \frac{\exp(\cos(\hat{P}^S, \hat{P}^Q))}{\sum \exp(\cos(\hat{P}^S, \hat{P}^Q))} Lm=−log∑exp(cos(P^S,P^Q))exp(cos(P^S,P^Q)),其中 P ^ S , P ^ Q \hat{P}^S, \hat{P}^Q P^S,P^Q 为精炼后的特征。
-
平衡通用与特定特征:交叉熵损失确保模型具备基础视觉理解能力,度量损失则针对小样本场景,聚焦任务特定的细粒度差异。噪声鲁棒性:RFM 生成的关键令牌减少背景噪声对相似度计算的影响,使度量损失更关注判别区域。
EXPERIMENTAL RESULTS AND ANALYSIS
A. Datasets
- 我们在四个细粒度的图像数据集上评估了所提出的方法,每个数据集按比例分成基本集、验证集和新集。每个数据集的数据分割细节可在表1中找到。与之前的研究一致,所有图像的大小都调整为84×84。数据集的简要概述如下:
-
Caltech-200-2011 (CUB) 是一个经典的细粒度鸟类图像数据集,包含200个类和11788张图像。根据,我们根据手动标记的边界框裁剪每个图像。
-
斯坦福狗数据集(Dogs) 是一个非常具有挑战性的细粒度图像数据集,它包含了来自世界各地120 类品种的20580张狗的图像。
-
Stanford Cars-196 (Cars) 也是FSFG图像分类中常用的数据集。它包含16185幅图像和196个车辆类别。
-
Oxford 102 Flowers (Flowers) 是一个包含总共102个类和8189个图像的花卉数据集。
-
-
表1实验数据集概述
-
B. Experimental Settings
-
1)架构:在我们的研究中,我们利用了两个广泛认可的卷积神经网络作为嵌入模型:Conv-4和ResNet-12 。对于Conv-4模型,我们修改了网络架构,删除了最后一个池层。这种改变使得最终输出的特征图具有64 × 10 × 10的维数。在ResNet-12的情况下,最终输出的特征图的大小为64 × 5 × 5。基于Conv-4主干的ATR-Net有0.4M个参数,计算成本为0.6G FLOPs,而基于ResNet-12主干的ATR-Net有22M个参数,计算成本为5.7G FLOPs。
-
2)训练设置:对于我们的方法,我们应用了标准的数据扩充技术,包括随机裁剪、水平翻转和颜色抖动。这些实验包括训练总共120个时期的模型。我们使用随机梯度下降(SGD)作为优化器,动量为0.9,权重衰减为5e-4。该模型的初始学习率被设置为0.1,在80和100个时期后降低10倍。关于任务令牌(k)和令牌化器函数(n)的参数,对于Conv-4模型,我们将它们设置为k = 5和n = 64,对于ResNet-12模型,将它们设置为k = 4和n = 16。对于超参数λ,它在实验中被设置为1.5。我们用PyTorch框架实现了我们提出的方法。
-
3)测试设置:我们使用元学习技术来训练和测试整个模型。具体来说,我们通过5路1样本或5路5样本集来训练网络。为了评估我们的结果,我们在每集每班测试了15个查询图像。报道的平均准确度基于2000次测试,我们提供了这些结果及其95%的置信区间。
-
-
表II基于CONV-4主干网的细粒度数据集(CUB、DOGS、CARS)上的5向少镜头分类精度。标有BY♢的方法表示使用与我们相同的架构的重新实现
-
C. Comparison With State-of-the-Arts
- 为了验证我们提出的方法的有效性,我们在上述四个细粒度数据集上进行了实验。我们将我们的方法与两类基线进行了比较。(1)第一类基线是标准的少样本学习方法,包括MatchingNet ,ProtoNet ,RelationNet ,DN4 ,ReNet ,MixFSL ,FRN 。(2)第二类基线是FSFG学习的最新方法,包括BSNet ,TOAN ,AGPF ,OLSA ,LCCRN ,TDM ,BiFRN ,BSFA 。表二、表三和表四显示了业绩评价。对于每个数据集,我们报告5路1次和5路5次分类的结果。观察结果可总结如下:
-
提出的ATR-Net在所有基准数据集上取得了最先进的结果。具体来说,不管主干是Conv-4还是ResNet-12,ATR-Net 在CUB、Cars和Flowers数据集上显示出比最先进的 FS 方法FRN大约2%到5%的改进。在Dogs数据集上,ATR-Net显示了更大的优势。当使用ResNet-12作为主干时,ATR-Net在5路1次设定中实现了75.68%的准确率,优于FRN 9.87%。
-
此外,ATR-Net在所有数据集上也比先进的FSFG方法BiFRN有很大的优势。特别是在Cars数据集上,我们的方法在5 样本上可以达到95.24%的最佳性能。与动物和植物不同,汽车是刚性物体,相应类别内的差异并不显著。这进一步证实了我们的方法在FSFG的重要性。
-
另一方面,尽管在CUB和Cars数据集上,ATR-Net在5路1次场景中的性能略低于BSFA,但在大多数情况下,特别是在Dogs数据集上,ATR-Net仍然表现出明显的优势。此外,我们的结果的置信区间宽度明显比大多数最先进的方法窄,表明ATR-Net能够提供更准确的预测。
-
-
表III基于RESNET-12主干网的细粒度数据集(CUB、DOGS、CARS)上的5向少数镜头分类精度
-
-
表四:FLOWERS数据集上的5路少样本分类精度。顶部模块使用CONV-4主干网,底部模块使用RESNET-12主干网
-
D. Ablation Study
-
TRGM、CRAM和RFM的影响:为了进一步证明每个模型组件在准确性和有效性方面的贡献,我们使用Conv-4作为主干网络在四个数据集上进行消融实验。我们在没有任何模块插入作为基线的情况下评估网络的准确性。随后,我们按顺序将每个模块插入网络。我们从相关定义的解释开始。
- 主干:这表示在没有任何模块插入的情况下评估网络的准确性,我们也使用这个结果作为基线。
- 主干+ TRGM:这意味着只插入TRGM模块。应当指出,TRGM的主要作用是获得当前任务的表示。如果没有CRAM,它就不能基于这种表示来识别任务相关的特征区域。因此,该模型的性能与仅使用主干网时相同,因此未在表中显示。
- 主干+ CRAM:这意味着只插入CRAM模块。此时,TRGM使用随机初始化,并且不从支持图像获得关于任务表示的信息。
- 主干+ RFM:这意味着只插入RFM模块。
- 主干+ TRGM + CRAM:这意味着TRGM和CRAM模块一起插入。
- 主干+ TRGM + RFM:这意味着 TRGM 和 RFM 模块被一起插入。与上述原因一致,没有CRAM,TRGM就不能产生效用。在这一点上,模型性能相当于主干网+ RFM,因此不在表中显示。
- 主干+ CRAM + RFM:这意味着 CRAM 和 RF 模块是一起插入的。此时,TRGM使用随机初始化,并且不从支持图像获得关于任务表示的信息。
- ATR-Net:这意味着所有的模块都被接入网络
-
TRGM 的作用:随机初始化任务令牌(不使用支持集交互)时,性能接近基线,证明动态生成任务表示的必要性。
-
RFM 的作用:仅用 CRAM(粗粒度区域)时,仍受背景噪声影响;加入 RFM 后,实例级特征更纯净,准确率提升 2-3%。
-
任务令牌数(k):当 k=5 时性能最佳,过少(k=1)无法捕捉多维度任务信息,过多(k>5)导致冗余和过拟合。
-
令牌化函数数(n):n=64 时平衡判别力与计算效率,过多会引入重复的空间注意力图,增加计算量但收益下降。任务令牌数 k 建议设为 4-6(平衡表达能力与计算量),令牌化函数数 n 根据 backbone 复杂度调整(Conv-4 用 64,ResNet-12 用 16)。
-
损失权重(λ):λ=1.5 时最优,过小(λ<1)导致度量损失不足,过大(λ>2)使模型过度依赖支持集,泛化性下降。
-
实验结果在表v中示出。从该表中,我们可以得出以下三个结论:
-
(1)随着模块被逐步插入,模型的性能逐渐提高,当所有模块都存在时达到其最大值。这证明了我们提出的方法的有效性。
-
(2)从(b)和(d)的比较中,很明显,当任务被随机初始化而不使用支持信息来获得当前任务时,模型的性能基本上与仅使用主干时相同。这证明了所提出的TRGM模块有益于该模型。
-
(3)仅包括TRGM-CRAM模块导致模型性能的适度改善。然而,当RFM也被合并时,在性能上有进一步的增强。这个结果强调了仅仅激活FSFG图像分类的任务相关区域的不足。值得注意的是,当TRGM、CRAM和RFM模块一起使用时,模型性能的改善最大,突出了它们的基本和协同作用。
-
-
表V 仅使用TRGM、CRAM或RFM的消融研究。我们在四个细粒度数据集上显示了95%置信区间的分类精度
-
-
混淆矩阵:此外,我们给出了CUB测试集上每个模型的混淆矩阵,如图7所示。我们可以观察到,在5路1次设置和5路5次设置中,与类之间的基线混乱相比,我们提出的模型显示出显著的改进。特别地,在5次设置中具有基线的类别1和类别5的精度分别是0.55和0.47,而具有ATR-Net模型的类别1和类别5的精度分别是0.70和0.75。
-
-
图7。CUB数据集上不同模块的混淆矩阵。
-
-
样本数量的影响:在这一部分,我们比较了ATR-Net、经典的少样本方法FRN和少镜头细粒度方法BiFRN在训练过程中使用不同镜头的分类性能。我们将主干设置为Conv-4,将路的数量设置为5。样本次数的结果记录在表VI中。
-
-
表VI在CONV-4主干下对于幼崽、DOGS、汽车和花朵数据集的不同拍摄数量的分类准确度
-
-
我们可以清楚地观察到,ATR-Net在四个细粒度数据集上对所有数量的镜头实现了最佳性能。这表明ATR-Net对不同拍摄次数的变化具有更好的鲁棒性。此外,我们还注意到,当镜头数为1时,三种方法之间的分类精度差距较大,这证明ATR-Net在更具挑战性的任务上更具优势。
E. Hyperparameter Analysis
-
任务令牌的数量(k)和令牌化器函数的数量(n)的影响:为了验证任务令牌的数量(k)和令牌化器函数的数量(n)对模型性能的影响,我们在CUB和Dogs数据集上进行实验验证,并使用Conv-4网络作为主干。结果如图8所示。基于所获得的结果,我们可以得出以下结论:
-
如图8 (a)和©所示,当任务令牌的数量(k)被设置为1时,该模型表现出低性能。随着任务表示的数量增加,观察到整体模型性能的上升趋势。因此,可以推断,学习多个任务表征使得模型能够获得更丰富的任务信息。否则,当任务表征的数量超过5或6时,模型的性能逐渐达到饱和并趋于降低,这表明获取过多数量的任务表征并不能产生额外的任务信息。
-
如图8 (b)和(d)的右图所示,随着表征器函数的数量(n)增加,模型的性能开始逐渐上升,并在n为64时达到峰值。这表明RFM在一定程度上可以有效地识别任务特征中的区分区域。然而,当n的数量大于64时,模型性能开始下降。原因可能是由于标记器函数选择了大量相同的空间语义信息,导致模型聚焦于图像的某些区域,这进一步影响了模型的泛化性能。
-
实验结果表明,k和n在不同的数据集上表现出相似的性能趋势。具体地,当k固定且n被设置为64时,以及同样地,当n固定且k被设置为5时,该模型达到最佳性能。这些发现表明,所提出的方法对不同数据集之间的参数变化并不高度敏感。
-
-
图8。在CUB和dogs数据集上关于不同数量的任务标记(k)和标记化器函数(n)的性能。
-
-
比例因子(λ)的影响:如等式(11)所述,λ是平衡基本损失和度量损失的比例因子,在实验中,我们使用Conv-4作为主干,并在CUB数据集上使用不同的λ值时验证模型的性能,如图9所示。
-
-
图9。不同尺度因子λ值下的分类精度。
-
-
我们可以看到,随着λ的增加,总体趋势是性能先上升后逐渐下降,当λ为1.5时,模型获得最佳性能。造成这种现象的原因是,当λ较小时,基损失对度规损失有正的作用,但当λ相当大时,度规损失成为损失函数损失的主要部分,降低了基损失的影响。因此,在我们的模型中,比例因子λ被设置为1.5。
F. The Classification Performance of ATR-Net on Coarse-Grained Images
-
我们还在粗粒度数据集mini-ImageNet上比较了分类精度。使用与细粒度数据集相同的训练策略,我们将我们的方法与少样本方法 FRN 和少样本细粒度方法双FRN进行了比较。结果如表VII所示。
-
-
表VII使用CONV-4和 RESNET-12主干的迷你图像网上1次和5次拍摄准确度(%)的比较
-
-
我们可以观察到,当使用Conv-4时,ATR-Net比FRN和BiFRN方法保持优势。然而,对于ResNet-12,ATR-Net的精确度略低于FRN方法。一个可能的原因是,迷你ImageNet图像比细粒度图像表现出更大的类间差异和更复杂的背景特征。当使用Conv-4网络处理特征时,由于难以处理复杂的任务,网络在获得的特征中包含更多的噪声信息。ATRNet不仅准确地捕捉当前任务的表征,而且使用较少的表征来表示当前特征信息,这大大降低了噪声信息的影响。因此,ATR-Net在Conv-4上表现出更好的性能。当主干是ResNet-12时,噪声信息相对减少,与更关注全局特征的FRN方法相比,削弱了ATR-Net的优势。
G. Feature Visualization
-
为了更全面地说明所提出方法的有效性,我们采用 CAM 来突出显示每个模型的关键感兴趣区域。在图10中,我们从CUB、Dogs和Cars数据集中随机选择两个原始图像,并将原始图像的大小调整为与嵌入层的输出相同的大小,并将调整大小的原始图像与CAM的输出进行比较。
-
-
图10。不同模型在CUB,dogs,cars数据集上区分区域激活抑制的可视化比较。区域越红,阶级歧视就越严重。
-
-
如图10所示,我们可视化了由支持和查询图像的基线、TRGM-CRAM和ATR-Net模型激活的区域。我们可以观察到,由于缺乏任务意识,基线激活了包含大量不相关信息的特征区域,例如“树枝”。当通过TRGMCRAM时,我们的模型激活与当前任务相关的特征区域,但是并不是所有获得的区域对于FSFG分类都是有效的。经过RFM模块后,我们的模型获得了更准确的判别特征区域。这表明ATR-Net可以基于当前任务信息获得更有效的特征表示。
DISCUSSION
-
为了获得更精确的任务表示和自适应去除图像背景噪声,我们提出了自适应任务感知精炼网络。然而,网络可能仍然缺乏灵活性。具体来说,当任务改变时,网络需要额外的时间来确定适当数量的令牌和令牌化器。如果标记或标记化器的数量更多或更少,模型可能实现更少的增益。在未来,探索自适应地确定模型中标记数量的方法将是必要的。
-
此外,虽然我们的方法取得了更好的性能,但我们发现ATR-Net比基线和其他比较方法具有更多的可学习参数,这是因为包含了任务标记和标记化器。当骨干网为Conv-4时,FRN和BIFRN的参数数量分别为0.23M和0.27M,而ATR-Net为0.44M。当骨干网为ResNet-12时,FRN和BIFRN的参数数量分别为12.42M和16.11M,ATR-Net为22.74M
CONCLUSION
-
在本文中,我们介绍了一种创新的自适应任务感知提炼网络(ATR-Net),专门用于少样本的细粒度图像分类。这项工作最重要的贡献在于它能够自适应地生成对应于当前任务的多种表示。这允许模型在各种任务之间平滑过渡,从而捕捉与每个任务的独特属性产生共鸣的更准确的特征表示。此外,我们创造性地将这些任务表征与单个图像通道相关联,以增强模型对图像任务级信息的访问,从而提高模型在识别特定任务特征区域中的精度。最后,我们设计了一个包含注意机制的细化过滤器模块,该模块使用多个具有空间注意的标记化器来为每个输入计算多个空间权重图,这些图可以学习特征中的重要区分区域。这使得网络能够自适应地获取图像中微妙的区别特征,并仅使用少数重要的视觉表征来表示它们。大量的实验评估表明,我们提出的网络在四个细粒度的图像数据集上一致地实现了优越的性能。
-
ATR-Net 主要由编码器(ConvNet 或 ResNet)、TACF 模块和 FC 层组成。编码器负责特征提取,TACF 处理任务相关的注意力机制。数据准备在 dataset_builder 中,支持多种数据集,如 miniimagenet、cub 等,使用 CategoriesSampler 生成小样本任务。Attention 类在 TACF.py 中,实现多头注意力,带有层缩放。Conv2d_fw 和 BatchNorm2d_fw 用于 MAML 风格的快速权重更新,但项目中可能主要用 ConvNet 作为编码器。TACF 模块的 forward 方法处理支持集和查询集的特征,生成任务特定区域。
-
数据流转从 dataset_builder 加载数据集,通过 DataLoader 和 CategoriesSampler 生成批次,输入到编码器,然后 TACF 处理,最后通过度量函数计算相似度。训练时结合 epi_loss 和辅助损失,使用 SGD 优化。小样本学习中的元学习,注意力机制用于任务特定特征提取,度量学习(余弦相似度)比较样本。
-
# dataset_builder:根据参数构建数据集 def dataset_builder(args): set_seed(args.seed) # 固定随机种子确保 reproducibility if args.dataset == 'miniimagenet': from models.dataloader.mini_imagenet import MiniImageNet as Dataset # ... 其他数据集加载逻辑 ... return Dataset # 训练数据加载(task_token_train.py) Dataset = dataset_builder(args) trainset = Dataset('train', args) # CategoriesSampler:生成小样本任务(N-way K-shot) train_sampler = CategoriesSampler( trainset.label, # 样本标签列表 len(trainset.data) // args.batch, # 每个epoch的迭代次数 args.way, # 类别数(N-way) args.shot + args.query # 每类样本数(K-shot + Query) ) train_loader = DataLoader(dataset=trainset, batch_sampler=train_sampler, num_workers=6)
-
采用
CategoriesSampler
按类别采样,确保每个批次包含args.way
个类别,每类args.shot
个支持样本和args.query
个查询样本。数据集加载 → 任务采样 → 输入编码器提取特征 → TACF 模块处理 → 度量计算与损失优化。 -
# ConvNet:4层卷积编码器(非ResNet版本) class ConvNet(nn.Module): def __init__(self, depth=4): super().__init__() trunk = [] for i in range(depth): indim = 3 if i == 0 else 64 # 前两层包含MaxPooling,后两层无(适应Relation Net结构) B = ConvBlock(indim, 64, pool=(i in [0, 1]), padding=0 if i in [0, 1] else 1) trunk.append(B) self.trunk = nn.Sequential(*trunk) def forward(self, x): # 分层输出特征(可用于中间层监控) out_0 = self.trunk[0](x) out_1 = self.trunk[1](out_0) out_2 = self.trunk[2](out_1) out_3 = self.trunk[3](out_2) return out_3 # 输出维度:(B, 64, H, W)
-
前两层卷积后接 MaxPooling 降采样,后两层仅卷积 + BN+ReLU,保留空间分辨率用于注意力区域定位。
backbone_res12.py
中的ResNet
使用 4 个 BasicBlock,每层包含 3×3 卷积和残差连接,最终输出维度 640(适用于更大图像输入)。
-
-
ATR-Net(Adaptive Task-Related Feature Learning)是针对小样本学习(Few-Shot Learning)的模型,核心思想是通过任务相关的注意力机制动态提取支持集(Support Set)和查询集(Query Set)的关键区域特征,增强类别判别能力。引入 TACF(Task-Adaptive Context Fusion)模块,利用多头注意力和任务描述子(Task Descriptor)实现跨样本的上下文融合,生成任务特定的区域特征。通过余弦相似度计算样本特征的相似性,结合辅助损失(Auxiliary Loss)提升模型泛化能力。
-
# 多头注意力层(带层缩放) class Attention(nn.Module): def __init__(self, dim, heads=8, dropout=0., temperature=1., layer_scale_init=-1): super().__init__() self.heads = heads self.dim_head = dim // heads self.scale = self.dim_head ** -0.5 # 缩放因子 self.temperature = temperature # 注意力温度参数 self.to_qkv = nn.Linear(dim, dim * 3) # QKV线性映射 self.to_out = nn.Sequential(nn.Linear(dim, dim), nn.Dropout(dropout)) # 层缩放参数(初始化为layer_scale_init或1) self.layer_scale = nn.Parameter(torch.ones(1, 1, dim) * layer_scale_init) if layer_scale_init > 0 else None def forward(self, x): b, n, _, h = *x.shape, self.heads qkv = self.to_qkv(x).chunk(3, dim=-1) # 重组维度为(b, heads, num_tokens, dim_head) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv) dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale # 点积计算 attn = dots.softmax(dim=-1) / self.temperature # 温度缩放后的softmax out = torch.einsum('bhij,bhjd->bhid', attn, v) # 注意力聚合 out = rearrange(out, 'b h n d -> b n (h d)') # 合并多头维度 # 残差连接与层缩放 return self.to_out(out) + x if self.layer_scale is None else self.to_out(out).mul_(self.layer_scale) + x
-
通过多头注意力捕捉不同子空间的特征依赖,
layer_scale
动态调整残差连接的贡献,避免梯度消失。通过 Transformer 编码器将支持集特征映射为任务描述子,用于指导查询集的注意力区域选择。
-
-
多头注意力机制(TACF 中的
Attention
类): Attention ( Q , K , V ) = Softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{Softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=Softmax(dkQKT)V 。 其中,通过to_qkv
线性层将输入映射为 Query、Key、Value,利用多头划分(Heads)并行计算注意力,增强特征表示能力。在残差连接中引入可学习的缩放因子layer_scale
,缓解深层网络梯度消失问题: Out = LayerScale × to_out ( X ) + X \text{Out} = \text{LayerScale} \times \text{to\_out}(X) + X Out=LayerScale×to_out(X)+X。计算支持集与查询集特征的余弦相似度作为分类依据:
( logits = cosine ( f support , f query ) × scale \text{logits} = \text{cosine}(f_{\text{support}}, f_{\text{query}}) \times \text{scale} logits=cosine(fsupport,fquery)×scale)。 -
TACF
(Task Adaptive Context Fusion)模块是论文中提出的一个关键组件,用于自适应地融合任务相关的上下文信息,以提高小样本学习的性能。TACF
模块的核心思想是通过生成任务描述符(task descriptors)来捕捉任务的上下文信息,并利用这些描述符自适应地选择和增强与任务相关的特征区域。-
LayerNorm:对输入特征进行层归一化,有助于稳定训练过程。
-
class LayerNorm(nn.Module): def __init__(self, dim, eps=1e-6): super().__init__() self.eps = eps self.g = nn.Parameter(torch.ones(1, dim, 1, 1)) self.b = nn.Parameter(torch.zeros(1, dim, 1, 1)) def forward(self, x): var = torch.var(x, dim=1, unbiased=False, keepdim=True) mean = torch.mean(x, dim=1, keepdim=True) return (x - mean) / (var + self.eps).sqrt() * self.g + self.b class PreNorm(nn.Module): def __init__(self, dim, fn): super().__init__() self.norm = nn.LayerNorm(dim) self.fn = fn def forward(self, x, **kwargs): return self.fn(self.norm(x), **kwargs)
-
MLP:用于对特征进行非线性变换。
-
class MLP(nn.Module): def __init__(self, dim, hidden_dim, dropout=0., layer_scale_init=-1): super().__init__() self.layer_scale_init = layer_scale_init self.net = nn.Sequential( nn.Linear(dim, hidden_dim), nn.GELU(), nn.Dropout(dropout), nn.Linear(hidden_dim, dim), nn.Dropout(dropout) ) if self.layer_scale_init > 0: self.layer_scale = nn.Parameter(torch.ones(1, 1, dim) * self.layer_scale_init) else: self.layer_scale = None def forward(self, x): return self.net(x) + x if self.layer_scale is None else self.net(x) * self.layer_scale + x
-
Attention:实现了多头自注意力机制,用于捕捉特征之间的依赖关系。
-
class Attention(nn.Module): def __init__(self, dim, heads=8, dropout=0., temperature=1., layer_scale_init=-1): super().__init__() self.heads = heads self.dim_head = dim // heads self.scale = self.dim_head ** -0.5 self.temperature = temperature self.layer_scale_init = layer_scale_init self.to_qkv = nn.Linear(dim, dim * 3) self.to_out = nn.Sequential( nn.Linear(dim, dim), nn.Dropout(dropout) ) if self.layer_scale_init > 0: self.layer_scale = nn.Parameter(torch.ones(1, 1, dim) * self.layer_scale_init) else: self.layer_scale = None def forward(self, x): b, n, _, h = *x.shape, self.heads qkv = self.to_qkv(x).chunk(3, dim=-1) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv) dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale attn = dots.softmax(dim=-1)/self.temperature out = torch.einsum('bhij,bhjd->bhid', attn, v) out = rearrange(out, 'b h n d -> b n (h d)') return self.to_out(out) + x if self.layer_scale is None else self.to_out(out).mul_(self.layer_scale) + x
-
Transformer:由多个自注意力层和 MLP 层组成,用于对特征进行深度变换。
-
class Transformer(nn.Module): def __init__(self, dim, depth, heads, mlp_expansion=4, dropout=0., temperature=1., layer_scale_init=-1): super().__init__() self.dim = dim self.depth = depth self.heads = heads self.mlp_dim = dim * mlp_expansion self.dropout = dropout self.temperature = temperature self.layer_scale_init = layer_scale_init self.layers = nn.ModuleList([]) for _ in range(depth): self.layers.append(nn.ModuleList([ PreNorm(dim, Attention(dim, heads=heads, dropout=dropout, temperature=temperature, layer_scale_init=layer_scale_init)), PreNorm(dim, MLP(dim, self.mlp_dim, dropout=dropout, layer_scale_init=layer_scale_init)) ])) def forward(self, x): for attn, mlp in self.layers: x = attn(x) x = mlp(x) return x
-
TaskDescriptorGenerator:从支持集特征中生成任务描述符。
-
class TaskDescriptorGenerator(nn.Module): def __init__(self, dim, depth, heads, num_descriptor, mlp_expansion=4, dropout=0., temperature=1., layer_scale_init=-1): super().__init__() self.dim = dim self.depth = depth self.heads = heads self.mlp_dim = dim * mlp_expansion self.dropout = dropout self.temperature = temperature self.layer_scale_init = layer_scale_init self.num_descriptor = num_descriptor self.task_descriptor = nn.Parameter(torch.randn(1, num_descriptor, dim)) self.transformer = Transformer(dim, depth, heads, mlp_expansion, dropout, temperature, layer_scale_init) def forward(self, support_feature): n, c, h, w = support_feature.shape support_feature = rearrange(support_feature, 'n c h w -> 1 (n h w) c') x = torch.cat([self.task_descriptor, support_feature], dim=1) x = self.transformer(x) task_descriptor = x[:, :self.num_descriptor, :] return task_descriptor
-
TaskSpecificRegionSelector:根据任务描述符选择与任务相关的特征区域。
-
class TaskSpecificRegionSelector(nn.Module): def forward(self, feature, key_channels): "feature: (B, C, H, W)" "key_channels: (1, M, C)" b, c, h, w = feature.shape m = key_channels.shape[1] feature = torch.reshape(feature, (b, 1, c, h, w)) key_channels = torch.reshape(key_channels, (1, m, c, 1, 1)) dot = feature * key_channels dot = torch.mean(dot, dim=2) dot = torch.sigmoid(dot) return dot
-
MAP:将选择的区域与原始特征进行融合,得到任务感知的特征。
-
class MAP(nn.Module): def __init__(self, in_dim, out_dim): super(MAP, self).__init__() def forward(self, task, feat): b, c, h, w = feat.shape m = task.shape[1] feature = torch.reshape(feat, (b, 1, c, h, w)) task = torch.reshape(task, (b, m, 1, h, w)) dot = feature * task + feature dot = torch.mean(dot, dim=1) return dot
-
RFM:对任务感知的特征进行进一步处理,生成最终的任务增强特征。
-
class RFM(nn.Module): def __init__(self, in_dim, num_token): super(RFM, self).__init__() self.in_dim = in_dim self.num_token = num_token self.selected_func = nn.Sequential( LayerNorm(self.in_dim), nn.Conv2d(self.in_dim, self.num_token, 3, 1, 1, bias=False), nn.ReLU(), nn.Conv2d(self.num_token, self.num_token, 3, 1, 1, bias=False), nn.ReLU(), nn.Conv2d(self.num_token, self.num_token, 3, 1, 1, bias=False), nn.ReLU(), nn.Conv2d(self.num_token, self.num_token, 3, 1, 1, bias=False), Rearrange('b n h w -> b n (h w)'), nn.Sigmoid() ) def forward(self, feature, task_specific_region): task = self.selected_func(task_specific_region) task = task[:, :, :, None] feat = rearrange(feature, 'b c h w -> b (h w) c') feat = feat[:, None, :, :] dot = feat * task dot = torch.mean(dot, dim=-2) dot = dot.permute(0, 2, 1) return dot
-
TACF:整合上述模块,完成任务自适应的上下文融合。
-
class TACF(nn.Module): def __init__(self, args, dim, descriptor_depth, heads, num_descriptor, num_token, mlp_expansion=4, dropout=0., temperature=1., layer_scale_init=-1): super().__init__() self.dim = dim self.descriptor_depth = descriptor_depth self.heads = heads self.num_descriptor = num_descriptor self.mlp_expansion = mlp_expansion self.dropout = dropout self.temperature = temperature self.layer_scale_init = layer_scale_init self.token = num_token self.args = args self.task_descriptor_generator = TaskDescriptorGenerator(dim, descriptor_depth, heads, num_descriptor, mlp_expansion, dropout, temperature, layer_scale_init) self.task_specific_region_selector = TaskSpecificRegionSelector() self.rfm = RFM(dim, self.token) self.map = MAP(num_descriptor, dim) def forward(self, support_feature, query_feature): task_descriptors = self.task_descriptor_generator(support_feature) task_specific_query_region = self.task_specific_region_selector(query_feature, task_descriptors) task_specific_support_region = self.task_specific_region_selector(support_feature, task_descriptors) task_att_qry = self.map(task_specific_query_region, query_feature) task_att_spt = self.map(task_specific_support_region, support_feature) task_specific_query_feature = self.rfm(task_att_qry, task_att_qry) task_specific_support_feature = self.rfm(task_att_spt, task_att_spt) return task_specific_support_feature, task_specific_query_feature
-
TACF
模块通过一系列的操作,从支持集特征中生成任务描述符,并利用这些描述符自适应地选择和增强与任务相关的特征区域。这种方法可以有效地捕捉任务的上下文信息,提高小样本学习的性能。具体来说,它利用了自注意力机制和多层感知机对特征进行深度变换,同时通过层归一化和前置归一化来稳定训练过程。最终,通过任务增强特征生成器生成的任务增强特征可以用于后续的分类任务。
-
-
ATR_Net
模型支持多种模式('fc'
,'encoder'
,'ATR_Net'
),通过设置model.module.mode
可以在不同模式之间切换,以实现不同的功能。在ATR_Net
的metric
方法中,使用normalize_feature
函数对特征进行归一化处理,有助于提高模型的稳定性和性能。ATR_Net
模型中集成了TACF
模块,用于自适应地融合任务相关的上下文信息,增强模型在小样本学习中的表现。在训练过程中,结合了元学习损失(epi_loss
)和辅助损失(loss_aux
),通过加权求和的方式得到最终的损失函数,有助于提高模型的泛化能力。使用torch.nn.utils.clip_grad_norm_
对模型的梯度进行裁剪,防止梯度爆炸,保证训练的稳定性。 -
训练集和验证集的数据是通过 dataset_builder 函数根据 args.dataset 参数动态加载的。不同的数据集对应不同的类,如 MiniImageNet, Cub 等。
-
Dataset = dataset_builder(args) trainset = Dataset('train', args) valset = Dataset('val', args)
-
训练集使用
CategoriesSampler
进行采样,每个批次包含args.way
个类别,每个类别有args.shot
个支持样本和args.query
个查询样本。 -
train_sampler = CategoriesSampler(trainset.label, len(trainset.data) // args.batch, args.way, args.shot + args.query) train_loader = DataLoader(dataset=trainset, batch_sampler=train_sampler, num_workers=6, pin_memory=True)
-
-
为了引入额外的监督信息,还构建了一个辅助训练集
trainset_aux
,并使用DataLoader
以随机打乱的方式加载数据。-
trainset_aux = Dataset('train', args) train_loader_aux = DataLoader(dataset=trainset_aux, batch_size=args.batch, shuffle=True, num_workers=8, pin_memory=True)
-
-
train 函数:首先将模型设置为训练模式,并初始化损失和准确率的记录器。然后遍历训练集和辅助训练集的数据,分别进行特征提取和损失计算。最后根据损失进行反向传播和参数更新。
-
def train(epoch, model, loader, optimizer, args=None): model.train() train_loader = loader['train_loader'] train_loader_aux = loader['train_loader_aux'] label = torch.arange(args.way).repeat(args.query).cuda() label = label.type(torch.LongTensor) label = label.cuda() loss_meter = Meter() acc_meter = Meter() k = args.way * args.shot criterion = nn.NLLLoss().cuda() tqdm_gen = tqdm.tqdm(train_loader) for i, ((data, train_labels), (data_aux, train_labels_aux)) in enumerate(zip(tqdm_gen, train_loader_aux), 1): data, train_labels = data.cuda(), train_labels.cuda() data_aux, train_labels_aux = data_aux.cuda(), train_labels_aux.cuda() model.module.mode = 'encoder' data = model(data) data_aux = model(data_aux) data_shot, data_query = data[:k], data[k:] if args.shot > 1: data_shot = data_shot.contiguous().view(args.shot, args.way, *data_shot.shape[1:]) data_shot = data_shot.mean(dim=0) model.module.mode = 'ATR_Net' logits, absolute_logits = model((data_shot, data_query)) epi_loss = F.cross_entropy(logits, label) absolute_loss = F.cross_entropy(absolute_logits, train_labels[k:]) model.module.mode = 'fc' logits_aux = model(data_aux) loss_aux = F.cross_entropy(logits_aux, train_labels_aux) loss_aux = loss_aux + absolute_loss loss = args.lamb * epi_loss + loss_aux acc = compute_accuracy(logits, label) loss_meter.update(loss.item()) acc_meter.update(acc) tqdm_gen.set_description(f'[train] epo:{epoch:>3} | avg.loss:{loss_meter.avg():.4f} | avg.acc:{acc_meter.avg():.3f} (curr:{acc:.3f})') loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 2.0) optimizer.step() optimizer.zero_grad() return loss_meter.avg(), acc_meter.avg(), acc_meter.confidence_interval()
-
-
train_main 函数:该函数是训练的主函数,负责初始化数据集、模型、优化器和学习率调度器。然后进行多个 epoch 的训练,并在每个 epoch 结束后进行验证。根据验证集的准确率保存最佳模型,并定期保存模型的快照。
-
def train_main(args): start_epoch = args.start_epoch stop_epoch = args.max_epoch max_acc, max_epoch = 0.0, 0 set_seed(args.seed) Dataset = dataset_builder(args) trainset = Dataset('train', args) train_sampler = CategoriesSampler(trainset.label, len(trainset.data) // args.batch, args.way, args.shot + args.query) train_loader = DataLoader(dataset=trainset, batch_sampler=train_sampler, num_workers=6, pin_memory=True) trainset_aux = Dataset('train', args) train_loader_aux = DataLoader(dataset=trainset_aux, batch_size=args.batch, shuffle=True, num_workers=8, pin_memory=True) train_loaders = {'train_loader': train_loader, 'train_loader_aux': train_loader_aux} valset = Dataset('val', args) val_sampler = CategoriesSampler(valset.label, args.val_episode, args.way, args.shot + args.query) val_loader = DataLoader(dataset=valset, batch_sampler=val_sampler, num_workers=6, pin_memory=True) val_loader = [x for x in val_loader] model = ATR_Net(args, resnet=args.resnet).cuda() model = nn.DataParallel(model, device_ids=args.device_ids) total = sum([param.nelement() for param in model.parameters()]) print('Number of parameter: % .2fM' % (total / 1e6)) if args.resume: resume_file = get_resume_file(args.save_path) print("resume_file:", resume_file) if resume_file is not None: tmp = torch.load(resume_file) start_epoch = tmp['epoch'] + 1 model.load_state_dict(tmp['params']) optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, nesterov=True, weight_decay=0.0005) lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.milestones, gamma=args.gamma) for epoch in range(start_epoch, args.max_epoch + 1): start_time = time.time() train_loss, train_acc, _ = train(epoch, model, train_loaders, optimizer, args) lr_scheduler.step() val_loss, val_acc, _ = evaluate(epoch, model, val_loader, args, set='val') if val_acc > max_acc: print(f'[ log ] *********A better model is found ({val_acc:.3f}) *********') max_acc, max_epoch = val_acc, epoch outfile = os.path.join(args.save_path, 'max_acc.pth') torch.save({'epoch': epoch, 'params': model.state_dict()}, outfile) if (epoch % args.save_freq == 0) or (epoch == stop_epoch - 1): outfile = os.path.join(args.save_path, '{:d}.pth'.format(epoch)) torch.save({'epoch': epoch, 'params': model.state_dict()}, outfile) epoch_time = time.time() - start_time print(f'[ log ] saving @ {args.save_path}') print(f'[ log ] roughly {(args.max_epoch - epoch) / 3600. * epoch_time:.2f} h left\n') return model
-