CVPR2022 MulT: 端到端的多任务学习transformer

An End-to-End Multitask Learning Transformer

论文:https://arxiv.org/pdf/2205.08303.pdf

code:https://github.com/IVRL/MulT

project: https://ivrl.github.io/MulT/

1.摘要

该文提出了一个端到端的多任务学习transformer框架,即 MulT,该框架可以同时学习对各高级视觉任务,包括深度估计,语义分割,reshading重着色,表面法线估计,2D关键点检测和边缘检测。基于swin-transformer模型,我们的框架将图像编码为共享表示,并使用基于特定任务的transformer解码器头对每个视觉任务进行预测。方法的核心是通过共享注意力机制对任务间的依赖关系进行建模。

通过在几个多任务基准上评估,本文提出的MulT的性能优于现有最先进的多任务卷积神经网络模型和所有各自的单任务transformer模型。

本章的实验进一步强调了在所有任务中共享注意力的好处,并证明MulT模型是稳健的,并且可以很好地泛化到新领域。

2.网络结构

如上图,Mult 模型基于swin-transformer  backbone(绿色部分),通过共享注意力机制(左下蓝色部分)对任务间的依赖关系进行建模。首先图像经encoder 编码模块(绿色部分)嵌入一个共享表示,然后通过transformer decoder解码模块(右端蓝色部分)对各个独立的任务进行解码。注意:transformer decoders具有相同的结构但接的是不同的任务头。整个模型通过监督方式采用各个任务的加权损失联合训练。

3.共享注意力机制

 为了说明任务间的依赖是在共享编码参数之外,我们设计了共享注意力机制,融合编码特征到解码流中。接下来通过一个特定的解码阶段来说明这个共享注意力机制是如何起作用的。注意在所有的解码阶段该注意力流程都有参与。

对于任务t和特定的解码阶段,x^{t} 表示为前一阶段的上采样输出,x_{sa}是同一分辨率下encode 阶段的输出。然后decoder将 x^{t}x_{sa}作为输入。标准方式来计算task t 自注意力是仅从decoder的输出x^{t}获得key,query和value 向量。

i而共享注意力,我们只利用一个任务流来计算注意力,也即,我们利用特定推理任务r的解码器的linear layers 从来自于encoder的x_{sa}计算一个query q_{sa}^{r}和key k_{sa}^{r} ,尽管如此,为了反映解码器的输出任务t应与此特定任务相关,我们计算value  v^{t}利用前一阶段任务t的输出x^{t}。因此,我们计算从推理任务r 计算attention values :

 式中C^{r} 是通道数,B^{r}是偏置。对于任务t,我们计算\tilde{x}^{t}=A_{sa}^{r}v^{t}。这里\tilde{x}^{t}后面被自注意力头head_{i}^{t} 用来计算 head_{i}^{t}(\tilde{x}_{i}^{t},W_{i}^{t})=\tilde{x}_{i}^{t}W_{i}^{t},这里 W_{i}^{t}是任务t学习到的注意力权重,\tilde{x}^{t}是第 i 通道。

注意这个方程表示自注意力的第i个实例,重复M次获得任务t的交叉注意力MHA^{t},根据这个我们计算x_{linear}^{t}通过线性投影 MHA^{t}输出,最后计算y^{t}如下:

这里W表示多头注意力权重。从经验上看,我们发现注意力来自表面法向量的任务流有利于我们6任务的MulT模型,因此我们将该任务作为参考任务r,其注意力是跨任务共享。如上图所示,x^{r}表示为前一阶段参考任务的特定编码器的上采样输出,此处作为曲面法线预测。

4.任务头和损失函数

来自transformer解码器模块的特征map被输入到不同的特定任务头,以进行后续预测。每个任务头包括一个线性层,以输出一个H×W×1的,map,其中H、W是输入图像尺寸。我们采用基于加权和的任务特定损失来联合训练网络,其中损失在每个任务的groundtruth和最终预测之间计算。对于分割,旋转,深度任务我们使用交叉熵损失,对于表面法线,2D关键点,2D边和重着色任务使用L1损失。另外,使用这些损失来保持与基线的一致性。

5.数据集

使用以下数据集评估MulT:

Taskonomy被用作我们的主要训练数据集。它包含400万幅真实的室内场景图像,每个图像的多任务注释。实验使用以下6项任务执行:语义分割(S)、深度(zbuffer)(D)、表面法线(N),2D关键点(K)、2D(Sobel)纹理边(E)和重着色(R)。选择的任务包括2D、3D和语义域,具有基于传感器/语义基础的GT。

Replica 包含1227张图像,高分辨率3D地面实况并且能够对细粒度进行更可靠的评估细节。我们在副本图像上测试了所有网络。

NYU包含1449 张来自464个不同的室内场景。

CocoDoom包含来自《末日》视频游戏的合成图像。我们将其用作未经训练的分布数据集。

6 测试效果

几篇CVPR关于multi-task的论文笔记整理,包括 一、 多任务课程学习Curriculum Learning of Multiple Tasks 1 --------------^CVPR2015/CVPR2016v--------------- 5 二、 词典对分类器驱动卷积神经网络进行对象检测Dictionary Pair Classifier Driven Convolutional Neural Networks for Object Detection 5 三、 用于同时检测和分割的多尺度贴片聚合(MPA)* Multi-scale Patch Aggregation (MPA) for Simultaneous Detection and Segmentation ∗ 7 四、 通过多任务网络级联实现感知语义分割Instance-aware Semantic Segmentation via Multi-task Network Cascades 10 五、 十字绣网络多任务学习Cross-stitch Networks for Multi-task Learning 15 --------------^CVPR2016/CVPR2017v--------------- 23 六、 多任务相关粒子滤波器用于鲁棒物体跟踪Multi-Task Correlation Particle Filter for Robust Object Tracking 23 七、 多任务网络中的全自适应特征共享与人物属性分类中的应用Fully-Adaptive Feature Sharing in Multi-Task Networks With Applications in Person Attribute Classification 28 八、 超越triplet loss:一个深层次的四重网络,用于人员重新识别Beyond triplet loss: a deep quadruplet network for person re-identification 33 九、 弱监督级联卷积网络Weakly Supervised Cascaded Convolutional Networks 38 十、 从单一图像深度联合雨水检测和去除Deep Joint Rain Detection and Removal from a Single Image 43 十一、 什么可以帮助行人检测?What Can Help Pedestrian Detection? (将额外的特征聚合到基于CNN的行人检测框架) 46 十二、 人员搜索的联合检测和识别特征学习Joint Detection and Identification Feature Learning for Person Search 50 十三、 UberNet:使用多种数据集和有限内存训练用于低,中,高级视觉的通用卷积神经网络UberNet: Training a Universal Convolutional Neural Network for Low-, Mid-, and High-Level Vision using Diverse Datasets and Limited Memory 62 一共13篇,希望能够帮助到大家
Swin Transformer是一种多任务学习框架,被用于解决多个视觉任务,如深度估计、语义分割、reshading重着色、表面法线估计、2D关键点检测和边缘检测。该框架使用了一种称为MulT端到端transformer模型,该模型将图像编码为共享表示,并使用基于特定任务的transformer解码器头对每个视觉任务进行预测。这个框架通过共享注意力机制来建模任务之间的依赖关系。 在几个多任务基准上的评估结果显示,MulT框架的性能优于现有最先进的多任务卷积神经网络模型和各自的单任务transformer模型。 在MulT框架中,来自transformer解码器模块的特征图被输入到不同的特定任务头,以进行后续预测。每个任务头包括一个线性层,用于输出一个H×W×1的特征图,其中H、W是输入图像的尺寸。网络使用基于加权和的任务特定损失来进行联合训练,该损失在每个任务的groundtruth和最终预测之间计算。对于分割、旋转和深度任务,使用交叉熵损失,对于表面法线、2D关键点、2D边和重着色任务使用L1损失。此外,还使用这些损失来保持与基线方法的一致性。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* *2* *3* [CVPR2022 MulT: 端到端多任务学习transformer](https://blog.csdn.net/qq_35831906/article/details/124859367)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 100%"] [ .reference_list ]
评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

LeapMay

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值