bert pytorch源码_【PyTorch】梯度爆炸、loss在反向传播变为nan

本文详细记录了作者在使用自定义损失函数训练BERT模型时遇到的梯度爆炸和loss变为nan的问题。通过对问题的分析,发现是由于特定情况下的梯度计算导致的。在PyTorch中尝试了多种改进方法,最终通过调整损失函数的实现,成功解决了这个问题,使得反向传播过程能够正确计算梯度。
摘要由CSDN通过智能技术生成

  点击上方“MLNLP”,选择“星标”公众号

重磅干货,第一时间送达

作者丨CV路上一名研究僧

知乎专栏丨深度图像与视频增强

地址丨https://zhuanlan.zhihu.com/p/79046709

0. 遇到大坑

笔者在最近的项目中用到了自定义loss函数,代码一切都准备就绪后,在训练时遇到了梯度爆炸的问题,每次训练几个iterations后,梯度和loss都会变为nan。一般情况下,梯度变为nan都是出现了 24cb73b6-2146-eb11-8da9-e4434bdf6706.svg , 25cb73b6-2146-eb11-8da9-e4434bdf6706.svg 等情况,导致结果变为+inf,也就成了nan。

1. 问题分析

笔者需要的loss函数如下:

26cb73b6-2146-eb11-8da9-e4434bdf6706.svg

其中, 2bcb73b6-2146-eb11-8da9-e4434bdf6706.svg 。

从理论上分析,这个loss函数在反向传播过程中很可能会遇到梯度爆炸,这是为什么呢?反向传播的过程是对loss链式求一阶导数的过程,那么, 2ccb73b6-2146-eb11-8da9-e4434bdf6706.svg 的导数为:

2dcb73b6-2146-eb11-8da9-e4434bdf6706.svg

由于 2ecb73b6-2146-eb11-8da9-e4434bdf6706.svg ,这个导数又可以表示为:

30cb73b6-2146-eb11-8da9-e4434bdf6706.svg

这样的话,出现了类似于 31cb73b6-2146-eb11-8da9-e4434bdf6706.svg 的表达式,也就会出现典型的$0/1$问题了。为了避免这个问题,首先进行了如下的 2ccb73b6-2146-eb11-8da9-e4434bdf6706.svg 改变:

34cb73b6-2146-eb11-8da9-e4434bdf6706.svg

经过改变,在$x_i=0$时,不再是 36cb73b6-2146-eb11-8da9-e4434bdf6706.svg 问题了,而是转换为了一个线性函数,梯度成为了恒定的12.9,从理论上来看,避免了梯度爆炸的问题。

2. PyTorch初步实现

在实现这一过程时,依旧...遇到了大坑,下面通过示例代码来说明:

"""        loss = mse(X, gamma_inv(X))        """        def loss_function(x):
mask = (x < 0.003).float()
gamma_x = mask * 12.9 * x + (1-mask) * (x ** 0.5)
loss = torch.mean((x - gamma_x) ** 2)return lossif __name__ == '__main__':
x = Variable(torch.FloatTensor([0, 0.0025, 0.5, 0.8, 1]), requires_grad=True)
loss = loss_function(x)print('loss:', loss)
loss.backward()print(x.grad)

改进后的 2ccb73b6-2146-eb11-8da9-e4434bdf6706.svg 是一个分支结构,在实现时,就采用了类似于Matlab中矩阵计算的mask方式,mask定义为 38cb73b6-2146-eb11-8da9-e4434bdf6706.svg ,满足条件的$x_i$在mask中对应位置的值为1,因此, 3acb73b6-2146-eb11-8da9-e4434bdf6706.svg 的结构只会保留 38cb73b6-2146-eb11-8da9-e4434bdf6706.svg 的结果,同样的道理, 3ccb73b6-2146-eb11-8da9-e4434bdf6706.svg 就实现了上述改进后的 2ccb73b6-2146-eb11-8da9-e4434bdf6706.svg 公式。

按理来说,此时,在反向传播过程中的梯度应该是正确的,但是,上面代码的输出结果为:

loss: tensor(0.0105, grad_fn=)
tensor([ nan, 0.1416, -0.0243, -0.0167, 0.0000])

emmm....依旧为nan,问题在理论层面得到了解决,但是,在实现层面依旧没能解决.....

3. 源码调试分析

上面源码的问题依旧在 2ccb73b6-2146-eb11-8da9-e4434bdf6706.svg 的实现,这个过程,在Python解释器解释的过程或许是这样的:

  1. 计算 40cb73b6-2146-eb11-8da9-e4434bdf6706.svg ,对mask进行广播式的乘法,结果为:原本为1的位置变为了12.9,原本为0的位置依旧为0;

  2. 将1.的结果继续与x相乘,本质上仍然是与x的每个元素相乘,只是mask中不满足条件的 43cb73b6-2146-eb11-8da9-e4434bdf6706.svg 位置为0,表现出的结果是仅对满足条件的 43cb73b6-2146-eb11-8da9-e4434bdf6706.svg 进行了计算;

  3. 按照2.所述的原理, 2ccb73b6-2146-eb11-8da9-e4434bdf6706.svg 公式的后半部分也是同样的计算过程,即, 49cb73b6-2146-eb11-8da9-e4434bdf6706.svg 中的每个值依旧会进行 4bcb73b6-2146-eb11-8da9-e4434bdf6706.svg 的计算;

按照上述过程进行前向传播,在反向传播时,梯度不是从某一个分支得到的,而是两个分支的题目相加得到的,换句话说,依旧没能解决梯度变为nan的问题。

4. 源码改进及问题解决

经过第三部分的分析,知道了梯度变为nan的根本原因是当 4dcb73b6-2146-eb11-8da9-e4434bdf6706.svg 时依旧参与了 4fcb73b6-2146-eb11-8da9-e4434bdf6706.svg 的计算,导致在反向传播时计算出的梯度为nan。

要解决这个问题,就要保证在 4dcb73b6-2146-eb11-8da9-e4434bdf6706.svg 时不会进行这样的计算。

新的PyTorch代码如下:

def loss_function(x):
mask = x < 0.003 gamma_x = torch.FloatTensor(x.size()).type_as(x)
gamma_x[mask] = 12.9 * x[mask]
mask = x >= 0.003 gamma_x[mask] = x[mask] ** 0.5 loss = torch.mean((x - gamma_x) ** 2)return lossif __name__ == '__main__':
x = Variable(torch.FloatTensor([0, 0.0025, 0.5, 0.8, 1]), requires_grad=True)
loss = loss_function(x)print('loss:', loss)
loss.backward()print(x.grad)

改变的地方位于`loss_function`,改变了对于 2ccb73b6-2146-eb11-8da9-e4434bdf6706.svg 分支的处理方式,控制并保住每次计算仅有满足条件的值可以参与。此时输出为:

loss: tensor(0.0105, grad_fn=)
tensor([ 0.0000, 0.1416, -0.0243, -0.0167, 0.0000])

就此,问题解决!

如有疑问,欢迎留言~

a4dd5530112c254894ef9a928bb8bac5.png

推荐阅读:

实战 | Pytorch BiLSTM + CRF做NER

如何评价Word2Vec作者提出的fastText算法?深度学习是否在文本分类等简单任务上没有优势?

从Word2Vec到Bert,聊聊词向量的前世今生(一)

65db891eb155ebd436793f78e5dbdf9c.png

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值