题目: IDM:域适应行人重识别的中间域模块
发表: ICCV 2021
文章链接:https://arxiv.org/pdf/2108.02413.pdf
代码链接:https://github.com/SikaStar/IDM
一. 本文主要工作:
- 基于“最短测地线路径”定义生成中间域。(最短测地线路径:通俗一点就是适当的中间区域应该位于连接源域和目标域的最短测地线路径上)。
生成的中间域应该满足两种性质:
(1) 中间域与源域和目标域能保持好的距离。
(2) 要足够多样化,以平衡源域和目标域的学习,并避免过度适应其中任何一个域。 - 提出两个数损失函数:bridge loss 和diversity loss。(两个损失函数我翻译为中间域损失和多样性损失,第一个本意是桥接损失,我感觉翻译为中间域损失更适合文章的意思)
bridge loss的作用:生成合适的中间域,使中间域与源域和目标域有正确的距离。此损失对应生成中间域需要满足的性质(1)。
diversity loss的作用:防止中间域偏向源域和目标域的任何一方。此损失对应生成中间域需要满足的性质(2)。
二. 模型架构图
注:图下标均对应原文的图标注
解释: 此图是中间域模块的生成图片。此模块共分成五部分。
(1) 源域数据通过平均池化层和最大池化层生成张量(张量可理解为多维度的数组,但是张量不是数组)。
(2)把步骤(1)生成的两个张量进行拼接。以上两步是对源域数据进行的处理,目标域数据也是一样的。
(3) 将拼接后的源于数据张量和目标域数据张量输入到全连接层,然后对全连接层输出的数据求和。
(4)经过MLP生成两个域因子,用于对源域和目标域进行加成。( MLP是支持向量机的简称。简单来说就是简单神经网络,这里实现使用的两个全连接层)
(5)根据域因子,对源域和目标域加成,生成中间域。中间域的生成公式如下:
IDM模块的源代码如下:
class IDM(nn.Module):
def __init__(self, channel=64):
super(IDM, self).__init__()
self.channel = channel
self.adaptiveFC1 = nn.Linear(2*channel, channel)
# FC2和FC3对应的是MLP
self.adaptiveFC2 = nn.Linear(channel, int(channel/2))
self.adaptiveFC3 = nn.Linear(int(channel/2), 2)
self.softmax = nn.Softmax(dim=1)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.max_pool = nn.AdaptiveMaxPool2d(1)
def forward(self, x):
if (not self.training):
return x
bs = x.size(0)
assert (bs%2==0)
# 由于源于数据和目标与数据传进来的是x,是拼接在一起的,这里需要分割开
split = torch.split(x, int(bs/2), 0)
x_s = split[0].contiguous() # [B, C, H, W]
x_t = split[1].contiguous()
# 此处对应的就是图2(b)中的(1)(2)部分
x_embd_s = torch.cat((self.avg_pool(x_s.detach()).squeeze(), self.max_pool(x_s.detach()).squeeze()), 1) # [B, 2*C]
x_embd_t = torch.cat((self.avg_pool(x_t.detach()).squeeze(), self.max_pool(x_t.detach()).squeeze()), 1)
# 此处对应的就是图2(b)中的(3)部分
x_embd_s, x_embd_t = self.adaptiveFC1(x_embd_s), self.adaptiveFC1(x_embd_t) # [B, C]
x_embd = x_embd_s+x_embd_t
# 此处对应的就是图2(b)中的(4)部分
x_embd = self.adaptiveFC2(x_embd)
lam = self.adaptiveFC3(x_embd)
lam = self.softmax(lam) # [B, 2]
# 此处对应的就是图2(b)中的(5)部分
x_inter = lam[:, 0].reshape(-1,1,1,1)*x_s + lam[:, 1].reshape(-1,1,1,1)*x_t
out = torch.cat((x_s, x_t, x_inter), 0)
return out, lam
三. IDM模块使用图
四. 模型实现
1. 域因子的生成
经过全连接层FC1,然后经过MLP,最后做softmax操作,得到两个域因子。
G
a
v
g
s
G^s_{avg}
Gavgs是源域数据经过平均池化层后的特征。
其中a是两个域因子拼接,a的结构是 。
2. 域因子生成中间域
对源域和目标域的特征分别使用生成的域因子累乘求和即可得到中间域特征。
3. 中间域损失
推导过程:
即源域到中间域的距离加上目标域到中间域的距离就是源域到目标域的距离。
2). 本文使用 λ \lambda λ 控制中间域在源域和目标域的位置。也就是
3).所以可以得到源域和中间域应该满足的公式:
λ
\lambda
λ就是域相关性因子,上面说到域因子使用
a
s
a^s
as和
a
t
a^t
at表示,又由于
a
s
a^s
as+
a
t
a^t
at = 1(四.1 公式(1)可知,a是由softmax函数得到的值,所以和为1),则上式可改写为:
4) 然后就得到了中间域损失:
作者在中间域的预测空间和特征空间上计算中间域的损失,如下:
公式(6) 是带有权值的交叉熵损失,用在预测空间上测量中间域和两个域的距离。
公式(7) 使用
L
2
L_2
L2范式来度量域之间的特征距离。
4. 多样性损失
公式中
σ
\sigma
σ 是在每个mini-batch计算标准差。
最小化
L
d
i
v
L_{div}
Ldiv的目的是:强制中间域尽可能多样化,以足够建模“最短测地线路径”的特征,这可以更好地桥接源领域和目标领域。
5. 总损失函数
L
R
e
I
D
L_{ReID}
LReID 是三元组损失和交叉熵损失求和(类似于别的文章的源域预训练的部分)。
6. 与其他模型的区别
一般模型会使用存储库来储存一个batch内的特征,进而计算损失,用来提升模型的性能。但是本文没有使用一个batch内的memory bank,使用的是跨批量的XMB来存储特征,进而引入了基于XBM的三元组损失来提升模型的性能。
剩余就是模型实验和算法流程,此处不再赘述。
注:此博客仅为自己见解,若有不正确的地方,很高兴和大家一起交流。本人研究小白,感谢大佬指教。