论文地址:https://arxiv.org/abs/2301.03033
代码地址:https://github.com/liuzywen/RGBTCC/tree/main
目的:解决RGB-T多模态下的人群计数,提出count-guided multi-modal fusion 和 modal-guided count enhancement
全文内容:
1. 解决人群计数下RGB-T任务,主要是解决多模态融合问题。之前的一些方法融合互补的多模态特征,但是缺乏计数约束。我们就考虑在多模态融合过程中添加计数约束,双模态融合就有了明确的目标。本文采取一个transformer架构来融合多模态并且设计了一个可学习的计数token来参与这个过程。这使得多模态融合能在计数指导下融合。
2. 另一个问题是尺度的大规模变化。这个是由于拍摄时人头近大远小导致的。之前一些方法用的是多列(MCNN),膨胀卷积,高分辨率,注意力机制来扩大感受野。在transformer框架下,我们提出了一个多尺度的token transformer来感知不同尺度的人。token被合并以形成具有不同长度的token序列,然后被送到一些并行 transformer中。这样处理后,特征的感受野就会多样。
backbone:
网络图如上,RGB和热图分别送入两个PVT encoder(不共享权重),PVT是一个多尺度transformer,会输出四阶段特征。high-layer特征语义信息更多,适合用于全局计数。因此我们拿两个模态从PVT最后一层的特征 Fr4+Ft4+一个额外的count token(有点像cls),沿着token方向(就是transformer输出的第二个维度,dim=1)做一个cat操作。然后把这个cat后的特征送到一个多尺度token transformer中(MSTTrans)。
MSTTrans
MSTTrans是为了解决大规模变化而提出的。灵感来自ASPP中多尺度设计。
首先Fr4+Ft4+一个额外的count token记为initial token sequence f1
具体来说,Fr4+Ft4 大小是N2 X C,N是token长度,C是channel大小。所以f1的大小是(2N2+1)XC
[·] 是concat沿着token 方向。
然后,将两个模态特征Fr4,Ft4分别分为N组,每组生成N个中尺度token。中尺度token再和 learnable count token cat在一起得到middle-scale token sequence f2.
f2大小是(2N+1)X C,其中merge操作是reshape+全连接。
再进行一遍相似的操作,得到f3 large-scale token sequence,f3大小是(2+1)XC
f1,f2,f3三个token sequence被送到三个多头自注意模块中,用于多模态交互。
其中MHSA表示两个多头自注意层。
f2和f3和f1的尺度不同,得到结果f2’和f3’后采用全连接层和reshape来恢复大小。
i只取2,3 。然后再把他们和f1cat在一起,沿着通道方向。MLP就是transformer中经典的两层FC。
最后就是g2',g3’和f1’ cat 后MLP一下。
G和MSTTrans输入f1一个维度,然后包含了优化后的r特征,t特征和count token。
Modal-guided counting enhancement
使用热模态来预测密度图和计数,并进一步使用RGB模态来细化预测。MSDTrans就算用了一下DETR中的multi-scale deformable attention,把Gt和Gcount cat后作为Q,Gr为K,F1r、F2r、F3r作为V放进multi-scale deformable attention。
得到模态增强后的Qt和Qcount。
回归头和loss
回归头:两个3X3卷积+一个1x1
loss是DM-COUNT的loss+一个L1范式监督count token
实验
数据集:RGBT-CC
评价指标:Grid Average Mean Absolute Error (GAME) 和 Root Mean Square Error (RMSE)
配置:GPU (NVIDIA RTX 3090); (2) input size(224×224); (3) train time (17 hours); (4) learning rate (1e−5); (5) weight decay (1e−4).
结果:
消融:
其中“Ours/count" 表示去learnable count token,“Ours/multi-scale" 表示普通多头 multihead self-attention替代所提multi-scale token transformer.