Context-Transformer: Tackling Object Confusion for Few-Shot Detection(AAAI20)
论文题目: Context-Transformer: Tackling Object Confusion for Few-Shot Detection
论文地址:https://arxiv.org/pdf/2003.07304.pdf
Introduction
-
少样本目标检测使用的一个常用策略是迁移学习(transfer learning)——在源域数据集上进行预训练,再在目标域上进行微调。出现的问题是:定位可以很好,但是分类效果不好,例如将马识别为狗,但是定位很准确。
-
分析原因:边框回归(BBOX)是类别无关的,目标分类和背景分类(OBJ+BG)是类别相关的,数据多样性低导致分类很容易混淆。为了解决这个问题,本文提出一种新的Context-Transformer,利用图片的上下文信息提高分类的准确度。比如区分物体是马还是狗,可以通过是否有人骑在上面进行分析判断。
Source Detection Transfer
-
整体设置
在一个大规模源域数据集上预训练得到一个识别 C s C_s Cs个物体类别的模型;目的是实现目标域上的少样本目标检测,具体来说,目标域包括 C t C_t Ct个物体类别,目标域上训练集仅包含 N shot;源域和目标域的物体类别是不重复的,用以建议模型对新类别的泛化程度。
-
检测基本框架采用SSD,具有多尺度检测的特点,能包含更多的上下文信息。对于每一个尺度,都包含一个边框回归(BBOX)、目标和背景分类(OBJ+BG)。
-
源域检测的迁移设置
- 源域中预训练的 BBOX,BG(二分类,识别是物体还是背景) 能够很好的迁移到目标域中。
- 对于 OBJ ,替换源域预训练的 OBJ 会引入大量随机初始化的参数,对于少样本训练不利。所以本文在源域训练的 OBJ 的基础上添加一个目标域的 OBJ ,这样减少了参数和过拟合。
- 在源域 OBJ 和目标域 OBJ 之间添加一个 Context-Transformer,嵌入该模块可以利用上下文信息,增强分类效果
-
整体框架
分成上下两个部分,上部分是常规的检测框架,下半部分为Context-Transformer模块。首先将target-domain的image输入在source-domain上训练好的SSD,然后将SSD其中一个输出OBJ输入Context-Transformer模块做一个特征增强,然后将增强后的特征和BG一起送入分类器进行分类。
Context-Transformer
主要包含两个模块:affinity discovery 和 context aggregation
Affinity Discovery
-
Prior Box
将一张target-domain的图片输入预训练的SSD中,然后抽取OBJ部分获取的向量
P k ∈ R H k × W k × ( M k × C s ) P_k \in R^{H_k \times W_k \times (M_k \times C_s) } Pk∈RHk×Wk×(Mk×Cs)
P k P_k Pk可以看作目标域上第 k k k 个尺度的第 m m m 个比率先验框在位置 ( h , w ) (h,w) (h,w)对应种类的预测向量。 -
Contextual Field
上下文的信息通过 Prior Box 构建,一种简单的想法就是直接使用所有的先验框作为 Contextual Field ,但是先验框的数量太多,一对多的信息对比会增加学习难度。因此本文在 P k P_k Pk 上做池化(关注整体信息,而不是小细节)获得 Q k ∈ R U k × V k × ( M k × C s ) Q_k \in R^{U_k \times V_k \times (M_k \times C_s) } Qk∈RUk×Vk×(Mk×Cs)
Q k = S p a t i a l P o o l ( P k ) , k = 1 , . . . , K Q_k = SpatialPool(P_k),\ k=1,...,K Qk=SpatialPool(Pk), k=1,...,K
U k × V k U_k \times V_k Uk×Vk是池化后的对应的大小。 -
Affinity Discovery
获取Prior Box和Contextual Field后,寻找计算他们之间的相关度分数。
首先将候选区域 P 和 Q reshape 成新的向量,具体表示如下:
P ∈ R D p × C s , Q ∈ R D q × C s D p = ∑ k = 1 K H k × W k × M k , D q = ∑ k = 1 K U k × V k × M k P \in R^{D_p \times C_s},\ Q\in R^{D_q \times C_s} \\ D_p = \sum_{k=1}^{K}H_k \times W_k \times M_k , D_q = \sum_{k=1}^{K}U_k \times V_k \times M_k P∈RDp×Cs, Q∈RDq×CsDp=k=1∑KHk×Wk×Mk,Dq=k=1∑KUk×Vk×Mk
然后利用矩阵乘法获得相关度分数矩阵 A ∈ R D p × D q A \in R^{D_p \times D_q} A∈RDp×Dq,计算过程如下:
A = f ( P ) × g ( Q ) T f ( P ) ∈ R D p × C s , g ( Q ) ∈ R D q × C s A = f(P) \times g(Q)^T \\ f(P) \in R^{D_p \times C_s}, \ \ \ g(Q) \in R^{D_q \times C_s} A=f(P)×g(Q)Tf(P)∈RDp×Cs, g(Q)∈RDq×Cs
f,g 为全连接层,相关度矩阵帮助先验框找到重要的对应上下文域。
Context Aggregation
获取获取Prior Box和Contextual Field对应的相关性矩阵后,将上下文信息融合到先验框中来增强分类效果,具体过程如下:
-
对于 A 的每一行进行softmax运算, s o f t m a x ( A ( i , : ) ) softmax(A(i,:)) softmax(A(i,:)) 表示每个contextual field对于先验框的重要性。利用下面公式获取每个先验框对应所有上下文信息的权重向量, h h h 为全连接层。
L ( i , : ) = s o f t m a x ( A ( i , : ) ) × h ( Q ) h ( Q ) ∈ R D q × C s L(i,:) = softmax(A(i,:))\times h(Q) \\ h(Q)\in R^{D_q \times C_s} L(i,:)=softmax(A(i,:))×h(Q)h(Q)∈RDq×Cs -
最后将上下文信息融合到先验框中,获取具有上下文信息感知的先验框 P ^ ∈ R D p × C s \widehat{P} \in R^{D_p \times C_s} P ∈RDp×Cs, φ \varphi φ 对应全连接层
P ^ = P + φ ( L ) φ ( L ) ∈ R D q × C s \widehat{P} = P + \varphi (L) \\ \varphi(L) \in R^{D_q \times C_s} P =P+φ(L)φ(L)∈RDq×Cs
Target OBJ
将 Context-Transformer 获得的
P
^
\widehat{P}
P
送入目标域类别分类器中,获取目标域类别分数矩阵:
Y
^
=
s
o
f
t
m
a
x
(
P
^
×
Θ
)
Y
^
∈
R
D
p
×
C
t
,
Θ
∈
R
C
s
×
C
t
\widehat{Y} = softmax(\widehat{P}\times \Theta) \\ \widehat{Y} \in R^{D_p \times C_t} ,\ \ \ \Theta\in R^{C_s \times C_t}
Y
=softmax(P
×Θ)Y
∈RDp×Ct, Θ∈RCs×Ct
Experiments
-
实验设置
source-domain:COCO中与VOC07+12不重叠的60类所涵盖的图片
target-domain:VOC07+12 -
Context-Transformer的效果
Baseline表示只用最原始的fine-tune方法,OBJ(S)代表保留source的OBJ。从表格中可看出,保留source的OBJ可以减轻过拟合,添加Context-Transformer模块可以减轻目标分类混淆的情况,它们都能提升mAP。作者还在test的时候移除了Context-Transformer模块,指标仅比在test阶段使用Context-Transformer模块略有下降,说明训练阶段使用的Context-Transformer模块使得检测器泛化能力增强。 -
shot数目对于模型效果的影响
增加shot并没有提升识别的效果,shot增加一定量的时候,模型的性能趋于稳定。shot量足够的时候,造成目标confusion的原因消失了,这时候Context-Transformer模块对于减轻confusion的作用已经很小了 -
和其他方法对比
对比了2个比较早期的基于SSD检测方法,提升了好几个点。
使用其它few-shot learner替换Context-Transformer
Conclusion
迁移学习的方法,在SSD检测框架中使用上下文信息来提高分类的效果,从而提升识别的效果。
参考:https://blog.csdn.net/chenxi1900/article/details/109347872