多任务学习中的梯度归一,GradNorm

该博客介绍了GradNorm,一种用于深度多任务网络的梯度归一化技术,旨在动态平衡不同任务的损失。通过调整任务梯度的幅度,GradNorm能促进网络性能提升,减少过拟合,而无需复杂的超参数调优。文中解释了如何使用梯度的L2范数来代表损失量级,并展示了如何通过梯度归一化来控制多任务学习中的训练速度。此外,还提供了代码片段来说明如何计算任务的梯度范数。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

一、论文中的一些重点内容

论文链接:https://arxiv.org/pdf/1711.02257.pdf,看这篇论文的作者是一家致力于将3D立体图像展示在现实生活中,最好是不用带特定的眼镜就可以肉眼融合,如果能做到这一步,那么对人类各个行业都会有很大的改变,比如建筑、制造行业,就可以用手来动态操作立体图像(就像钢铁侠中那种设计一样)

(1) 我们跨任务而不是跨数据批次标准化
(2) 我们使用速率平衡作为期望的目标来通知我们的标准化。我们将证明这种梯度归一化(以下称为 GradNorm)提升了网络性能,同时显着减少过拟合

我们对多任务学习的主要贡献如下:
1.一种高效的多任务损失平衡算法,它直接调整不同任务的梯度幅度。
2.匹配或超越性能的方法,非常昂贵的详尽网格搜索程序,但这只需要调整一个超参数。
3. 直接梯度交互提供了一种控制多任务学习的强大方式的演示。

(1) 将不同任务的梯度范数放在一个共同的尺度下面,我们可以通过它来推断它们的相对大小
(2) 动态调整梯度范数,以至于不同的任务以相似的速度训练。 为此,我们首先定义相关量,需要是和梯度相关。

优秀文章:GradNorm:Gradient Normalization for Adaptive Loss Balancing in Deep Multitask Networks,梯度归一化_Leon_winter的博客-CSDN博客_gradnorm

二、重点讲解

GiW(t)是任务 i 的带权损失,对需要更新的神经网络参数W(W表示神经网络参数,wi表示label loss中各个label的权重)的梯度的L2范数,这里用 GiW(t) 代表 loss的量级,为什么可以用权重的梯度的L2范数代表loss量级,可以具体看下面这篇文章讲解每个参数反向传播时,如果loss越大,反向传播时候的每个参数的梯度也就越大

一文弄懂神经网络中的反向传播法——BackPropagation - Charlotte77 - 博客园

公式中的i代表第i个学习任务(ctr cvr可以当算作一个i),Li(0)代表第0次迭代时候的loss,Li(t)代表第t次迭代时候的loss,这两个相比,越大说明到第t次时迭代越慢(和Li(0)的差距很小)。那么 Li(t)~ 和 Etask[Li(t)~](代表的是第t次迭代时,多个目标loss的均值,如果是ctr cvr两个目标的话,那就就是ctr和cvr的平均loss;而且分母中对于不同任务是定值,大小取决于分子,你的训练速度慢,分子就大,最终的值就大)的比值 ri(t) 也就越大ri(t)也叫做相对反向训练速度(任务迭代的越慢,ri(t)越大)

上面的公式证明了两点,1:训练loss过大时;2:训练速度过快时,都会导致【L(grad)(t,wit) -> 也可以称为:Grad Loss】过大,但是为什么Grad Loss过大,会导致wi变小

因为Grad Loss过大,我们对它求导,导数值(梯度更新公式中的β(t))也就越大,w(t+1)=w(t) + lambda * β(t),lambda在文章中说是学习率,但是这里的 + 应该是 - 才对,那么wi就会变小

三、一些track

1:α是超参数,α \alphaα越大,对训练速度的平衡限制越强。为了节约计算时间,G r a d   L o s s Grad~LossGrad Loss仅对shared bottom的输出部分计算。
2:

3:

 4:公式中的GiW(t)在代码中是怎么体现的

# get layer of shared weights
W = model.get_last_shared_layer()

# get the gradient norms for each of the tasks
# G^{(i)}_w(t) 
norms = []
for i in range(len(task_loss)):
# get the gradient of this task loss with respect to the shared parameters
    gygw = torch.autograd.grad(task_loss[i], W.parameters(), retain_graph=True)
    # compute the norm
    norms.append(torch.norm(torch.mul(model.weights[i], gygw[0])))

整个shared_bottomk会有连续多层,最后一层的参数就是就是上面的W,然后每个任务的Loss(i)对W进行求导,得到梯度,再求梯度的L2范数。最后再和权重model.weight[i]相乘

评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值