Multi-task learning MTL 中的多任务loss平衡问题
背景
multi-task的损失函数:
L ( t ) = ∑ w i ( t ) L i ( t , θ ) L(t)=\sum{w_i(t)L_i(t,\theta)} L(t)=∑wi(t)Li(t,θ)
在multi-task训练中存在着:
- 如何平衡各个任务损失的权重,
- 对于不同的任务的loss梯度之间的大小关系如果平衡,
- 各任务学习率如何控制.
这些问题影响mult-task训练的最终效果. 处理不当, 很有可能一个task学的很好, 其他task学的很差.
三个问题是相通的.
7 Nov 2017 - GradNorm: Gradient Normalization for Adaptive Loss Balancing in Deep Multitask Networks
论文地址
GradNorm的解决方法, 将 w i w_i wi作为可学习的参数.
目标是:
- 在不考虑学习率的情况下, 尽量使得各个task的 w i ( t ) L i ( t , θ ) w_i(t)L_i(t,\theta) wi(t)Li(t,θ)对于参数 θ \theta θ的gradient都与平均值接近.
- 学习不充分的task, 给与比其他task更大的学习率.
t t t时间task i i i的 θ \theta θ梯度:
G θ ( i ) ( t ) = ∥ ▽ θ w i ( t ) L i ( t , θ ) ∥ 2 G_\theta^{(i)}(t)=\Vert \bigtriangledown_{\theta}w_i(t)L_i(t,\theta)\Vert_2 Gθ(i)(t)=∥▽θwi(t)Li(t,θ)∥2
所有task对 θ \theta θ的平均梯度:
G ‾ θ ( t ) = E t a s k [ G θ ( i ) ( t ) ] \overline{G}_\theta(t)=E_{task}[G_\theta^{(i)}(t)] Gθ(t)=Etask[Gθ(i)(t)]
对于学习率, 定义若干个变量:
定义一个loss相对于初始化时的占比, 优化程度.
L ~ i ( t ) = L i ( t ) / L i ( 0 ) \widetilde{L}_i(t)=L_i(t)/L_i(0)