论文讲解:IDM: An Intermediate Domain Module for Domain Adaptive Person Re-ID

题目: IDM:域适应行人重识别的中间域模块
发表: ICCV 2021

文章链接:https://arxiv.org/pdf/2108.02413.pdf
代码链接:https://github.com/SikaStar/IDM

一. 本文主要工作:

  1. 基于“最短测地线路径”定义生成中间域。(最短测地线路径:通俗一点就是适当的中间区域应该位于连接源域和目标域的最短测地线路径上)。
    生成的中间域应该满足两种性质
    (1) 中间域与源域和目标域能保持好的距离。
    (2) 要足够多样化,以平衡源域和目标域的学习,并避免过度适应其中任何一个域。
  2. 提出两个数损失函数:bridge lossdiversity loss。(两个损失函数我翻译为中间域损失和多样性损失,第一个本意是桥接损失,我感觉翻译为中间域损失更适合文章的意思)
    bridge loss的作用:生成合适的中间域,使中间域与源域和目标域有正确的距离。此损失对应生成中间域需要满足的性质(1)
    diversity loss的作用:防止中间域偏向源域和目标域的任何一方。此损失对应生成中间域需要满足的性质(2)

二. 模型架构图

:图下标均对应原文的图标注
在这里插入图片描述

图2 (b)

解释: 此图是中间域模块的生成图片。此模块共分成五部分。
(1) 源域数据通过平均池化层和最大池化层生成张量(张量可理解为多维度的数组,但是张量不是数组)。
(2)把步骤(1)生成的两个张量进行拼接。以上两步是对源域数据进行的处理,目标域数据也是一样的。
(3) 将拼接后的源于数据张量和目标域数据张量输入到全连接层,然后对全连接层输出的数据求和。
(4)经过MLP生成两个域因子,用于对源域和目标域进行加成。( MLP是支持向量机的简称。简单来说就是简单神经网络,这里实现使用的两个全连接层)
(5)根据域因子,对源域和目标域加成,生成中间域。中间域的生成公式如下:
在这里插入图片描述

IDM模块的源代码如下:
class IDM(nn.Module):
    def __init__(self, channel=64):
        super(IDM, self).__init__()
        self.channel = channel
        self.adaptiveFC1 = nn.Linear(2*channel, channel)
        # FC2和FC3对应的是MLP
        self.adaptiveFC2 = nn.Linear(channel, int(channel/2))
        self.adaptiveFC3 = nn.Linear(int(channel/2), 2)
        self.softmax = nn.Softmax(dim=1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

    def forward(self, x):

        if (not self.training):
            return x

        bs = x.size(0)
        assert (bs%2==0)
        #	 由于源于数据和目标与数据传进来的是x,是拼接在一起的,这里需要分割开
        split = torch.split(x, int(bs/2), 0)
        x_s = split[0].contiguous() # [B, C, H, W]
        x_t = split[1].contiguous()
		
		#	此处对应的就是图2(b)中的(1)(2)部分 
        x_embd_s = torch.cat((self.avg_pool(x_s.detach()).squeeze(), self.max_pool(x_s.detach()).squeeze()), 1)  # [B, 2*C]
        x_embd_t = torch.cat((self.avg_pool(x_t.detach()).squeeze(), self.max_pool(x_t.detach()).squeeze()), 1)
		#	此处对应的就是图2(b)中的(3)部分 
        x_embd_s, x_embd_t = self.adaptiveFC1(x_embd_s), self.adaptiveFC1(x_embd_t) # [B, C]
        x_embd = x_embd_s+x_embd_t
        #	此处对应的就是图2(b)中的(4)部分 
        x_embd = self.adaptiveFC2(x_embd)
        lam = self.adaptiveFC3(x_embd)
        lam = self.softmax(lam) # [B, 2]
        #	此处对应的就是图2(b)中的(5)部分 
        x_inter = lam[:, 0].reshape(-1,1,1,1)*x_s + lam[:, 1].reshape(-1,1,1,1)*x_t
        out = torch.cat((x_s, x_t, x_inter), 0)
        return out, lam

三. IDM模块使用图

在这里插入图片描述

图2 (a)
此图是IDM模块在ResNet-50上的使用图,原文做了实验,发现在第0层之后添加IDM模块,模型的效果更好。此图是在第一个ResNet块后添加,和最好的模型略有区别。

四. 模型实现

1. 域因子的生成

在这里插入图片描述
经过全连接层FC1,然后经过MLP,最后做softmax操作,得到两个域因子。 G a v g s G^s_{avg} Gavgs是源域数据经过平均池化层后的特征。
其中a是两个域因子拼接,a的结构是 在这里插入图片描述

2. 域因子生成中间域

在这里插入图片描述
对源域和目标域的特征分别使用生成的域因子累乘求和即可得到中间域特征。

3. 中间域损失
推导过程:

在这里插入图片描述

图2 (c)
1). 此图表示的是根据最短测地线距离,需满足公式

在这里插入图片描述
即源域到中间域的距离加上目标域到中间域的距离就是源域到目标域的距离。

2). 本文使用 λ \lambda λ 控制中间域在源域和目标域的位置。也就是在这里插入图片描述

3).所以可以得到源域和中间域应该满足的公式:
在这里插入图片描述
λ \lambda λ就是域相关性因子,上面说到域因子使用 a s a^s as a t a^t at表示,又由于 a s a^s as+ a t a^t at = 1(四.1 公式(1)可知,a是由softmax函数得到的值,所以和为1),则上式可改写为:
在这里插入图片描述

4) 然后就得到了中间域损失:
在这里插入图片描述

作者在中间域的预测空间和特征空间上计算中间域的损失,如下:
在这里插入图片描述
公式(6) 是带有权值的交叉熵损失,用在预测空间上测量中间域和两个域的距离。
公式(7) 使用 L 2 L_2 L2范式来度量域之间的特征距离。

4. 多样性损失

在这里插入图片描述
公式中 σ \sigma σ 是在每个mini-batch计算标准差。
最小化 L d i v L_{div} Ldiv的目的是:强制中间域尽可能多样化,以足够建模“最短测地线路径”的特征,这可以更好地桥接源领域和目标领域。

5. 总损失函数

在这里插入图片描述
L R e I D L_{ReID} LReID 是三元组损失和交叉熵损失求和(类似于别的文章的源域预训练的部分)。

6. 与其他模型的区别

一般模型会使用存储库来储存一个batch内的特征,进而计算损失,用来提升模型的性能。但是本文没有使用一个batch内的memory bank,使用的是跨批量的XMB来存储特征,进而引入了基于XBM的三元组损失来提升模型的性能。
剩余就是模型实验和算法流程,此处不再赘述。
注:此博客仅为自己见解,若有不正确的地方,很高兴和大家一起交流。本人研究小白,感谢大佬指教。

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值