【多任务损失】用不确定度来自适应设定各损失函数的权重


🌔在多任务中,总损失常常是各分任务损失的线性加权,此时各分任务损失的权重设定就显得尤为重要。

比如在JDE算法中就使用了多任务损失的自适应设定:
在这里插入图片描述
这里使用的权重自适应方法便来自以下这篇文章发掘的不确定度与权重的关系:

Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics

本文正是基于此,用步步衔接的数学推导,讲解多任务损失函数权重设定原理。

总的来说用不确定度估计多任务损失权重的核心是:

各任务的权重与其对应的噪声成反相关

0️⃣ 多任务独立性

⭐️对于模型输出 f W ( x ) \mathbf{f^W(x)} fW(x)以及对应的真值 y \mathbf{y} y,多任务(假设是相互独立的)的概率估计可以表示为以下这种形式:
p ( y 1 , … , y K ∣ f W ( x ) ) = p ( y 1 ∣ f W ( x ) ) … p ( y K ∣ f W ( x ) ) p(\mathbf{y_1,\dots,y_K|f^W(x)})=p(\mathbf{y_1|f^W(x)})\dots p(\mathbf{y_K|f^W(x)}) p(y1,,yKfW(x))=p(y1fW(x))p(yKfW(x))
假设使用 − log ⁡ -\log log来计算损失,则有:
L t o t a l = − log ⁡ [ p ( y 1 , … , y K ∣ f W ( x ) ) ] = ( − log ⁡ [ p ( y 1 ∣ f W ( x ) ) ] ) + ⋯ + ( − log ⁡ [ p ( y K ∣ f W ( x ) ) ] ) = L 1 + ⋯ + L K L_{total}\\=-\log[p(\mathbf{y_1,\dots,y_K|f^W(x)})]\\=(-\log[p(\mathbf{y_1|f^W(x)})])+\dots+ (-\log[p(\mathbf{y_K|f^W(x)})])\\=L_1+\dots+L_K Ltotal=log[p(y1,,yKfW(x))]=(log[p(y1fW(x))])++(log[p(yKfW(x))])=L1++LK
因而,基于此就可以使用概率估计来融合各分任务的损失表达。

1️⃣ 连续类损失(回归)

⭐️以回归问题为例,讲解对连续类损失的处理方式。

假设回归问题的概率满足Gaussian分布,依此估计其概率,则有:
p ( y ∣ f W ( x ) , σ ) = 1 2 π σ e − ( y − f W ( x ) ) 2 2 σ 2 p(\mathbf{y|f^W(x),\sigma})=\frac{1}{\sqrt{2\pi}\sigma}e^{-\frac{(y-f^W(x))^2}{2\sigma^2}} p(y∣fW(x),σ)=2π σ1e2σ2(yfW(x))2
写成损失的形式:
− log ⁡ ( y ∣ f W ( x ) , σ ) = ( y − f W ( x ) ) 2 2 σ 2 − log ⁡ ( 1 2 π σ ) ≈ ∥ y − f W ( x ) ∥ 2 2 σ 2 + log ⁡ ( σ ) -\log(\mathbf{y|f^W(x),\sigma})={\frac{(y-f^W(x))^2}{2\sigma^2}}-\log(\frac{1}{\sqrt{2\pi}\sigma})\approx {\frac{\parallel y-f^W(x)\parallel_2}{2\sigma^2}}+\log({\sigma}) log(y∣fW(x),σ)=2σ2(yfW(x))2log(2π σ1)2σ2yfW(x)2+log(σ)

2️⃣ 离散类损失(分类)

⭐️以分类问题为例,讲解对离散类损失的处理方式。

使用Softmax对分类输出进行处理,依此估计其概率,其中 σ \sigma σ是一个尺度因子:
p ( y = c ∣ f W ( x ) , σ ) = S o f t m a x ( y , 1 σ 2 f W ( x ) ) = e 1 σ 2 f c W ( x ) ∑ c ′ e 1 σ 2 f c ′ W ( x ) p(\mathbf{y}=c|\mathbf{f^W(x),\sigma})=Softmax(\mathbf{y},\frac{1}{\sigma^2}\mathbf{f^W(x)})=\displaystyle \frac{e^{\frac{1}{\sigma^2}f^W_c(\mathbf{x})}}{\sum_{c'}e^{\frac{1}{\sigma^2}f^W_{c'}(\mathbf{x})}} p(y=cfW(x),σ)=Softmax(y,σ21fW(x))=ceσ21fcW(x)eσ21fcW(x)
写成损失的形式:
− log ⁡ ( p ( y = c ∣ f W ( x ) , σ ) ) = − 1 σ 2 f c W ( x ) + log ⁡ ( ∑ c ′ e 1 σ 2 f c ′ W ( x ) ) -\log(p(\mathbf{y}=c|\mathbf{f^W(x),\sigma}))=-\frac{1}{\sigma^2}f^W_c(\mathbf{x})+\log({\sum_{c'}e^{\frac{1}{\sigma^2}f^W_{c'}(\mathbf{x})}}) log(p(y=cfW(x),σ))=σ21fcW(x)+log(ceσ21fcW(x))

3️⃣ 多任务损失组合

⭐️假设有两个任务:回归任务1(连续),分类任务2(离散),则有:
L ( W , σ 1 , σ 2 ) = − log ⁡ ( p ( y 1 , y 2 = c ∣ f W ( x ) ) ) = − log ⁡ [ ( 1 2 π σ 1 e − ( y 1 − f W ( x ) ) 2 2 σ 1 2 ) ⋅ ( e 1 σ 2 2 f c W ( x ) ∑ c ′ e 1 σ 2 2 f c ′ W ( x ) ) ] = ∥ y 1 − f W ( x ) ∥ 2 2 σ 1 2 + log ⁡ ( σ 1 ) − 1 σ 2 2 f c W ( x ) + log ⁡ ( ∑ c ′ e 1 σ 2 2 f c ′ W ( x ) ) = ∥ y 1 − f W ( x ) ∥ 2 2 σ 1 2 + log ⁡ ( σ 1 ) + [ − 1 σ 2 2 f c W ( x ) + 1 σ 2 2 log ⁡ ( ∑ c ′ e f c ′ W ( x ) ) ] − 1 σ 2 2 log ⁡ ( ∑ c ′ e f c ′ W ( x ) ) + log ⁡ ( ∑ c ′ e 1 σ 2 2 f c ′ W ( x ) ) = 1 2 σ 1 2 L 1 ( W ) + 1 σ 2 2 L 2 ( W ) + log ⁡ ( σ 1 ) + log ⁡ [ ∑ c ′ e 1 σ 2 2 f c ′ W ( x ) ( ∑ c ′ e f c ′ W ( x ) ) 1 σ 2 2 ] ≈ 1 2 σ 1 2 L 1 ( W ) + 1 σ 2 2 L 2 ( W ) + log ⁡ ( σ 1 ) + log ⁡ ( σ 2 ) L(\mathbf{W},\sigma_1,\sigma_2)\\=-\log(p(\mathbf{y_1,y_2}={c}|\mathbf{f^W(x)}))\\=-\log[(\frac{1}{\sqrt{2\pi}\sigma_1}e^{-\frac{(y_1-f^W(x))^2}{2\sigma_1^2}})\cdot(\displaystyle \frac{e^{\frac{1}{\sigma_2^2}f^W_c(\mathbf{x})}}{\sum_{c'}e^{\frac{1}{\sigma_2^2}f^W_{c'}(\mathbf{x})}})]\\={\frac{\parallel y_1-f^W(x)\parallel_2}{2\sigma_1^2}}+\log({\sigma_1})-\frac{1}{\sigma_2^2}f^W_c(\mathbf{x})+\log({\sum_{c'}e^{\frac{1}{\sigma_2^2}f^W_{c'}(\mathbf{x})}}) \\ ={\frac{\parallel y_1-f^W(x)\parallel_2}{2\sigma_1^2}}+\log({\sigma_1})+[-\frac{1}{\sigma_2^2}f^W_c(\mathbf{x})+\frac{1}{\sigma_2^2}\log(\sum_{c'}e^{f^W_{c'}(x)})]-\frac{1}{\sigma_2^2}\log(\sum_{c'}e^{f^W_{c'}(x)})+\log({\sum_{c'}e^{\frac{1}{\sigma_2^2}f^W_{c'}(\mathbf{x})}}) \\ =\frac{1}{2\sigma_1^2}L_1(\mathbf{W})+\frac{1}{\sigma_2^2}L_2(\mathbf{W})+\log(\sigma_1)+\log[\frac{\sum_{c'}e^{\frac{1}{\sigma_2^2}f^W_{c'}(\mathbf{x})}}{(\sum_{c'}e^{f^W_{c'}(\mathbf{x})})^\frac{1}{\sigma_2^2}}]\\\approx \frac{1}{2\sigma_1^2}L_1(\mathbf{W})+\frac{1}{\sigma_2^2}L_2(\mathbf{W})+\log(\sigma_1)+\log(\sigma_2) L(W,σ1,σ2)=log(p(y1,y2=cfW(x)))=log[(2π σ11e2σ12(y1fW(x))2)(ceσ221fcW(x)eσ221fcW(x))]=2σ12y1fW(x)2+log(σ1)σ221fcW(x)+log(ceσ221fcW(x))=2σ12y1fW(x)2+log(σ1)+[σ221fcW(x)+σ221log(cefcW(x))]σ221log(cefcW(x))+log(ceσ221fcW(x))=2σ121L1(W)+σ221L2(W)+log(σ1)+log[(cefcW(x))σ221ceσ221fcW(x)]2σ121L1(W)+σ221L2(W)+log(σ1)+log(σ2)
其中,
{ L 1 ( W ) = ∥ y 1 − f W ( x ) ∥ 2 L 2 ( W ) = − log ⁡ [ S o f t m a x ( y 2 , f W ( x ) ) ] = − log ⁡ ( e f c W ( x ) ∑ c ′ e f c ′ W ( x ) ) 1 σ 2 2 ∑ c ′ e 1 σ 2 2 f c ′ W ( x ) ≈ ( ∑ c ′ e f c ′ W ( x ) ) 1 σ 2 2 \begin{cases}L_1(\mathbf{W})=\parallel y_1-f^W(x)\parallel_2\\L_2(\mathbf{W})=-\log[Softmax(y_2,f^W(x))]=-\log(\displaystyle \frac{e^{f^W_c(\mathbf{x})}}{\sum_{c'}e^{f^W_{c'}(\mathbf{x})}})\\\frac{1}{\sigma_2^2}\sum_{c'}e^{\frac{1}{\sigma^2_2}f^W_{c'}(x)}\approx(\sum_{c'}e^{f^W_{c'}(\mathbf{x})})^\frac{1}{\sigma_2^2}\end{cases} L1(W)=∥y1fW(x)2L2(W)=log[Softmax(y2,fW(x))]=log(cefcW(x)efcW(x))σ221ceσ221fcW(x)(cefcW(x))σ221

4️⃣ 实际操作中的近似处理

⭐️最后,在实际操作中,通过预测 s : = log ⁡ ( σ 2 ) s:=\log(\sigma^2) s:=log(σ2)来代替预测 σ 2 \sigma^2 σ2,因为这样在数值上更加稳定,而且没有除0等问题,则最后多任务损失可以近似写成:
L ( W , s 1 , s 2 ) = 1 e s 1 L 1 ( W ) + 1 e s 2 L 2 ( W ) + s 1 + s 2 L(\mathbf{W},s_1,s_2)=\frac{1}{e^{s_1}}L_1(\mathbf{W})+\frac{1}{e^{s_2}}L_2(\mathbf{W})+s_1+s_2 L(W,s1,s2)=es11L1(W)+es21L2(W)+s1+s2

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值