focal loss dice loss源码_扒源码:sharding loss in Pytorch

d362e1374e9b3dd40f6a3a06c2d92ebd.png
源码: https:// github.com/OpenNMT/Open NMT-py/blob/master/onmt/utils/loss.py

OpenNMT 在计算 NMT 模型的loss时,进行了“shard”处理,即对解码序列 y 分段进行梯度回溯。如果不这样做,超长序列的 loss.backward() 时会导致显存爆炸。

写一点小代码来演示一下。比如 一个简单的网络,输入 x,输入 y:

import 

定义 loss 为求和,梯度回退后打印 x 的梯度出来看看:

loss = torch.sum(y)
loss.backward()
print(x.grad)
# tensor([ 2.,  4.,  6.]

梯度计算正确(即 2x)。以上是普通流程。现在开始分段,直接对每个 shard 回溯:

y_list = torch.split(y, 1)
for z in y_list:
    loss = torch.sum(z)
    loss.backward()
    print(x.grad)
# tensor([ 2.,  0.,  0.])
# RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

在循环的第一步,x 得到了梯度 [2, 0, 0],确实是 y[0] 的梯度回溯结果。可是在循环的第二步,同一个 graph 连续两次进行 backward,报错了。加上 “retain_graph=True”可以解决这个问题,不过那和不分 shard 也没区别了。

正确的做法是,每个 shard 单独计算 loss,但是只回溯到 y,阻断到网络内部(即 x)的更新。在遍历结束后,统一更新。请看注释。

v_list = []
g_list = []

for z in torch.split(y, 1):
    t = z.data.clone()       # 只对data进行clone(),阻断到网络内部的梯度传递;
    t.requires_grad = True   # 将t变成leaf node;

    loss = torch.sum(t)      # 计算loss。OpenNMT将generator包在了里面,shard后更省显存;
    loss.backward()          # 只能计算出t的梯度, 到不了x;

    v_list.append(z)         # 注意t只是计算梯度用的,真正连接着网络内部的是z,即y的shard;
    g_list.append(t.grad)    # 把t的梯度保存下来,它和z的梯度是一样的;

# 对v_list中的每一个shard,使用g_list中对应的梯度,继续回溯直至x;
torch.autograd.backward(v_list, g_list)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值