When we want to customize how a layer computes its gradient in PyTorch, we can subclass torch.autograd.Function. This lets us define our own gradient rule in the backward method, which also makes backward computation possible for functions that are not differentiable in the usual sense.
The source code of this class is available at:
https://pytorch.org/docs/stable/_modules/torch/autograd/function.html
First, a demo provided by the official documentation:
class Exp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, i):
        result = torch.exp(i)
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        result, = ctx.saved_tensors
        return grad_output * result
In forward, this function computes $y = e^{x}$.
In backward, given the gradient of $y$ (grad_output), it computes the gradient of $x$:

$grad\_output * \frac{\partial{y}}{\partial x} = grad\_output * e^{x}$
- First, the forward function. It must take a context ctx as its first argument, followed by any number of other arguments. ctx can store tensors via save_for_backward so they can be retrieved during the backward pass. forward defines the forward computation path.
- Next, the backward function. It must also take a context ctx as its first argument; the second argument, grad_output, holds the gradient of the tensor returned by forward. It returns one gradient per input of forward, in the same order. ctx.needs_input_grad is a tuple of booleans indicating whether each input requires a gradient; e.g., if ctx.needs_input_grad[0] is False, the first input of forward does not need a gradient, and we can simply return None in that position (see the sketch after this list).
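As a minimal sketch (the Mul function and its two-input signature are illustrative, not from the original text), backward can consult ctx.needs_input_grad to skip gradients that are not required:

import torch

class Mul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        return a * b

    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.saved_tensors
        grad_a = grad_b = None
        # Compute a gradient only when the corresponding input requires it;
        # positions that need no gradient are returned as None.
        if ctx.needs_input_grad[0]:
            grad_a = grad_output * b
        if ctx.needs_input_grad[1]:
            grad_b = grad_output * a
        return grad_a, grad_b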
The layer is invoked through the apply method; the Function defined this way can be reused like an ordinary function in different places.
import torch

class Exp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, i):
        result = torch.exp(i)
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        result, = ctx.saved_tensors
        return grad_output * result

x1 = torch.tensor([3., 4.], requires_grad=True)
x2 = Exp.apply(x1)  # apply is called on the class, not on an instance
y2 = Exp.apply(x2)
y2.sum().backward()
print(x1.grad)
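The custom backward can also be verified against numerical differentiation with torch.autograd.gradcheck (a quick sketch; gradcheck needs double-precision inputs to be reliable):

# gradcheck compares the analytical backward with finite differences
# and returns True when they agree within tolerance.
x = torch.randn(5, dtype=torch.double, requires_grad=True)
print(torch.autograd.gradcheck(Exp.apply, (x,)))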
Next, a slightly more complex example, a linear transformation:
$y = xW^{T} + b$

$\frac{\partial{y}}{\partial x} = W,\quad \frac{\partial{y}}{\partial W} = x,\quad \frac{\partial{y}}{\partial b} = 1$
Here we define a Linear module:
import torch
import torch.nn as nn

class LinearFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, bias=None):
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        # Skip gradients that are not needed (see ctx.needs_input_grad above).
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)
        return grad_input, grad_weight, grad_bias

class Linear(nn.Module):
    def __init__(self, input_features, output_features, bias=True):
        super(Linear, self).__init__()
        self.input_features = input_features
        self.output_features = output_features
        self.weight = nn.Parameter(torch.Tensor(output_features, input_features))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(output_features))
        else:
            self.register_parameter('bias', None)
        self.weight.data.uniform_(-0.1, 0.1)
        if self.bias is not None:  # check the parameter, not the bool flag
            self.bias.data.uniform_(-0.1, 0.1)

    def forward(self, input):
        # See the autograd section for explanation of what happens here.
        return LinearFunction.apply(input, self.weight, self.bias)
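A quick usage sketch (the batch size and feature sizes here are illustrative) confirming that gradients flow to the input, the weight, and the bias:

layer = Linear(4, 3)
inp = torch.randn(2, 4, requires_grad=True)
out = layer(inp)  # shape (2, 3)
out.sum().backward()
print(inp.grad.shape, layer.weight.grad.shape, layer.bias.grad.shape)
# torch.Size([2, 4]) torch.Size([3, 4]) torch.Size([3])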
Another example from the official documentation; the returned gradients must correspond one-to-one to the inputs:
import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable

class Func(Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
        w = x * z
        out = x * y + y * z + w * y
        ctx.save_for_backward(x, y, w, out)
        ctx.z = z  # z is not a tensor, so it is stashed directly on ctx
        return out

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_out):
        x, y, w, out = ctx.saved_tensors
        z = ctx.z
        gx = grad_out * (y + y * z)
        gy = grad_out * (x + z + w)
        gz = None  # the int input z gets no gradient
        return gx, gy, gz

a = torch.tensor(1., requires_grad=True, dtype=torch.double)
b = torch.tensor(2., requires_grad=True, dtype=torch.double)
c = 4
d = Func.apply(a, b, c)
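Continuing the example (the printed values below follow from the formulas in backward), calling backward fills a.grad and b.grad, while the int input c receives no gradient:

d.backward()
# ∂d/∂a = b + b*c = 2 + 8 = 10,  ∂d/∂b = a + c + a*c = 1 + 4 + 4 = 9
print(a.grad, b.grad)  # tensor(10., dtype=torch.float64) tensor(9., dtype=torch.float64)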
Computing the gradient of cross-entropy:
import torch
import torch.nn.functional as F

logits = torch.tensor([[0.2, 0.3, 0.9],
                       [0.2, 0.3, 0.9]], requires_grad=True)
targets = torch.tensor([0, 0])

class CrossEntropyLoss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, logits, targets):
        targets = F.one_hot(targets, num_classes=logits.size(1)).float()
        prob = F.softmax(logits, 1)
        ctx.save_for_backward(prob, targets)
        logits = F.log_softmax(logits, 1)
        loss = -(targets * logits).sum(1).mean()
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        prob, targets = ctx.saved_tensors
        # d(loss)/d(logits) = (softmax(logits) - one_hot(targets)) / batch_size
        grad_logits = (grad_output * (prob - targets)) / targets.size(0)
        grad_targets = None  # integer class labels get no gradient
        return grad_logits, grad_targets

auto_loss = CrossEntropyLoss.apply(logits, targets)
auto_loss.backward()
print(logits.grad)
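As a sanity check (a sketch reusing the variables defined above), the gradient can be compared against the built-in F.cross_entropy:

ref_logits = logits.detach().clone().requires_grad_(True)
ref_loss = F.cross_entropy(ref_logits, targets)
ref_loss.backward()
# The custom backward should match the built-in implementation.
print(torch.allclose(logits.grad, ref_logits.grad))  # True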