梯度剪裁: torch.nn.utils.clip_grad_norm_()


前言

当神经网络深度逐渐增加,网络参数量增多的时候,反向传播过程中链式法则里的梯度连乘项数便会增多,更易引起梯度消失和梯度爆炸。对于梯度爆炸问题,解决方法之一便是进行梯度剪裁,即设置一个梯度大小的上限。本文介绍了pytorch中梯度剪裁方法的原理和使用方法。


一、原理

注:为了防止混淆,本文对神经网络中的参数称为“网络参数”,其他程序相关参数成为“参数”。

pytorch中梯度剪裁方法为 torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type=2)1。三个参数:

parameters:希望实施梯度裁剪的可迭代网络参数
max_norm:该组网络参数梯度的范数上限
norm_type:范数类型

官方对该方法的描述为:

"Clips gradient norm of an iterable of parameters. The norm is computed over all gradients together, as if they were concatenated into a single vector. Gradients are modified in-place."

“对一组可迭代(网络)参数的梯度范数进行裁剪。效果如同将所有参数连接成单个向量来计算范数。梯度原位修改。”

我们来逐段分析其实现代码:

def clip_grad_norm_(parameters, max_norm, norm_type=2):
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    max_norm = float(max_norm)
    norm_type = float(norm_type)

该部分处理了传入的三个参数。首先将parameters中的非空网络参数存入一个列表,然后将max_normnorm_type类型强制为浮点数。

    if norm_type == inf:
        total_norm = max(p.grad.data.abs().max() for p in parameters)

该句对无穷范数进行了单独计算,即取所有网络参数梯度范数中的最大值,定义为total_norm
t o t a l _ n o r m ∞ = max ⁡ p i ∈ P ∣ g r a d ( p i ) ∣ {total\_norm}^{\infty}=\max_{pi\in {P}}|grad(p_i)| total_norm=piPmaxgrad(pi)

    else:
        total_norm = 0
        for p in parameters:
            param_norm = p.grad.data.norm(norm_type)
            total_norm += param_norm.item() ** norm_type
        total_norm = total_norm ** (1. / norm_type)

对于其他范数,我们计算所有网络参数梯度范数之和,再归一化,即等价于把所有网络参数放入一个向量,再对向量计算范数。将结果定义为total_norm
t o t a l _ n o r m n o r m _ t y p e = { ∑ p i ∈ P [ g r a d ( p i ) ] n o r m _ t y p e } 1 n o r m _ t y p e {total\_norm}^{norm\_type}=\{\sum_{pi\in {P}}[grad(p_i)]^{norm\_type}\}^{\frac{1}{norm\_type}} total_normnorm_type={ piP

  • 105
    点赞
  • 208
    收藏
    觉得还不错? 一键收藏
  • 14
    评论
评论 14
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值