Neural networks learn via gradient descent, and as the number of layers grows, the "exploding gradient" problem can become increasingly severe. For example: during backpropagation, if every layer's partial derivative of output with respect to input is greater than 1, the gradient grows as it propagates through more layers, and the gradient may eventually explode.
When gradients explode, a single update step can jump straight past the optimum, so gradient clipping is needed to keep the network from overshooting the optimal solution during training.
This post implements three gradient-clipping strategies in numpy and compares each against the corresponding pytorch/paddle implementation.
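The compounding effect can be seen with a toy calculation (illustrative only; the function name `backprop_gradient_scale` is made up for this sketch): the chain rule multiplies per-layer local derivatives, so a constant factor above or below 1 grows or shrinks exponentially with depth.

```python
# Toy illustration: if every layer's local derivative has magnitude k,
# the gradient reaching the first layer of an L-layer network scales like k**L.
def backprop_gradient_scale(local_deriv: float, num_layers: int) -> float:
    scale = 1.0
    for _ in range(num_layers):
        scale *= local_deriv  # chain rule multiplies local derivatives
    return scale

print(backprop_gradient_scale(1.2, 50))  # roughly 9.1e3: gradient explosion
print(backprop_gradient_scale(0.8, 50))  # roughly 1.4e-5: gradient vanishing
```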
1. Clip by value
Clamp each gradient element to the range [-value, value]: entries below -value become -value, entries above value become value. In formula form:

    g' = min(max(g, -value), value)
The numpy implementation, checked against torch:
import torch
import numpy as np

def clip_grad_by_value(grad, value):
    # Clamp in place: elements outside [-value, value] are set to the bound.
    grad[grad < -value] = -value
    grad[grad > value] = value
    return grad

np.random.seed(10001)
grad = np.random.rand(4, 100) * 3 - 1.5
value = 1.0
print('grad before clip, min: {}, max: {}'.format(grad.min(), grad.max()))
# torch.from_numpy shares memory with grad, so take the torch copy first;
# otherwise the in-place numpy clip below would already have clipped it
# and the comparison would be trivially equal.
torch_grad = torch.from_numpy(grad.copy())
np_cliped_grad = clip_grad_by_value(grad, value)
torch.nn.utils.clip_grad_value_(torch_grad, value)
print('grad after clip, min: {}, max: {}'.format(torch_grad.numpy().min(), torch_grad.numpy().max()))
print('numpy grad and torch grad error: ', np.sum(np_cliped_grad - torch_grad.numpy()))
assert np.allclose(np_cliped_grad, torch_grad.numpy()), "clipped grad not equal"

# Output:
# grad before clip, min: -1.4982050803275673, max: 1.4785602913399787
# grad after clip, min: -1.0, max: 1.0
# numpy grad and torch grad error:  0.0
2. Clip by global norm
The torch function that clips gradients by norm is torch.nn.utils.clip_grad_norm_. It treats all gradients together as one vector and limits their combined (global) norm:

    global_norm = ( sum_i ||g_i||_p^p )^(1/p)
    g_i' = g_i * max_norm / global_norm    (if global_norm > max_norm; otherwise unchanged)
The numpy implementation, checked against torch (L2 norm):
import torch
import numpy as np

def clip_grad_by_global_norm(grad, max_norm, norm_type=2):
    # The norm of the per-tensor norms equals the norm of all elements concatenated.
    global_norm = np.linalg.norm(
        np.stack([np.linalg.norm(g.ravel(), ord=norm_type) for g in grad]).ravel(),
        ord=norm_type,
    )
    clip_coef = max_norm / (global_norm + 1e-6)
    if clip_coef > 1:
        clip_coef = 1
    normed_grad = []
    for g in grad:
        normed_grad.append(g * clip_coef)
    return normed_grad

def clip_grad_norm_(
        parameters, max_norm: float, norm_type: float = 2.0) -> torch.Tensor:
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters]
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    total_norm = torch.norm(torch.stack([torch.norm(p, norm_type) for p in parameters]), norm_type)
    clip_coef = max_norm / (total_norm + 1e-6)
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
    for p in parameters:
        p.mul_(clip_coef_clamped)
    return total_norm

np.random.seed(10001)
grad1 = np.random.rand(2, 40) * 3 - 1.5
grad2 = np.random.rand(2, 40) * 2 - 1
grad = [grad1, grad2]
max_norm = 1.0
norm_type = 2
np_cliped_grad = clip_grad_by_global_norm(grad, max_norm)
# torch.from_numpy shares memory with grad1/grad2, but np_cliped_grad was
# built from fresh arrays above, so the in-place torch clip does not affect it.
torch_grad = [torch.from_numpy(g) for g in grad]
total_norm = clip_grad_norm_(torch_grad, max_norm=max_norm, norm_type=norm_type)
for idx, (np_grad, t_grad) in enumerate(zip(np_cliped_grad, torch_grad)):
    print("Number {} of grad in grad list error: {}.".format(idx, np.sum(np.abs(np_grad - t_grad.numpy()))))
    assert np.allclose(np_grad, t_grad.numpy()), "clipped grad not equal"

# Output:
# Number 0 of grad in grad list error: 7.707593244199451e-16.
# Number 1 of grad in grad list error: 4.758021233952636e-16.
PS: the clip_grad_norm_ function above is adapted from torch.nn.utils.clip_grad_norm_. The stock version computes on each parameter's .grad attribute, whereas here plain tensors (the data) are passed in directly, so it was modified slightly.
3. Clip by norm
torch.nn.utils.clip_grad_norm_ computes a single norm over all gradients. paddlepaddle also offers a per-tensor variant that computes the norm of each tensor's gradient separately and clips that tensor on its own. The corresponding class is paddle.nn.ClipGradByNorm, with the formula:

    g' = g * clip_norm / ||g||_2    (if ||g||_2 > clip_norm; otherwise g' = g)

where ||g||_2 = sqrt( sum_i g_i^2 ) and g_i denotes the gradient of a single parameter value in the tensor.
The numpy implementation, checked against paddle.nn.ClipGradByNorm:
import paddle
import numpy as np

def clip_grad_by_norm(test_data, clip_norm):
    cliped_data = []
    for data, grad in test_data:
        # Per-tensor L2 norm; only tensors whose norm exceeds clip_norm are scaled.
        norm = np.sqrt(np.sum(np.square(np.array(grad))))
        if norm > clip_norm:
            grad = grad * clip_norm / norm
        cliped_data.append((data, grad))
    return cliped_data

np.random.seed(10001)
data = (np.random.rand(2, 40) * 3 - 1.5).astype('float32')
grad = (np.random.rand(2, 40) * 2 - 1).astype('float32')
data = [(data, grad)]
clip_norm = 1.0
np_cliped_data = clip_grad_by_norm(data, clip_norm)
paddle_data = [(paddle.to_tensor(d[0]), paddle.to_tensor(d[1])) for d in data]
clip_layer = paddle.nn.ClipGradByNorm(clip_norm=clip_norm)
clipped_data = clip_layer(paddle_data)
for idx, (np_pair, paddle_pair) in enumerate(zip(np_cliped_data, clipped_data)):
    print("Number {} of grad in grad list error: {}.".format(idx, np.sum(np.abs(np_pair[1] - paddle_pair[1].numpy()))))
    assert np.allclose(np_pair[1], paddle_pair[1].numpy()), "clipped grad not equal"

# Output:
# Number 0 of grad in grad list error: 1.6816193237900734e-07.