自定义后向传播
forward() 网络层的计算操作,能够根据你的需要设置参数。
backward() 梯度计算操作。
继承Function
class LinearFunction(Function):
# Note that both forward and backward are @staticmethods
@staticmethod
# bias is an optional argument
def forward(ctx, input, weight, bias=None):
ctx.save_for_backward(input, weight, bias)
output = input.mm(weight.t())
if bias is not None:
output += bias.unsqueeze(0).expand_as(output)
return output
# This function has only a single output, so it gets only one gradient
@staticmethod
def backward(ctx, grad_output):
# This is a pattern that is very convenient - at the top of backward
# unpack saved_tensors and initialize all gradients w.r.t. inputs to
# None. Thanks to the fact that additional trailing Nones are
# ignored, the return statement is simple even when the function has
# optional inputs.
input, weight, bias = ctx.saved_variables
grad_input = grad_weight = grad_bias = None
# These needs_input_grad checks are optional and there only to
# improve efficiency. If you want to make your code simpler, you can
# skip them. Returning gradients for inputs that don't require it is
# not an error.
if ctx.needs_input_grad[0]:
grad_input = grad_output.mm(weight)
if ctx.needs_input_grad[1]:
grad_weight = grad_output.t().mm(input)
if bias is not None and ctx.needs_input_grad[2]:
grad_bias = grad_output.sum(0).squeeze(0)
return grad_input, grad_weight, grad_bias
有多少个输入,就要返回对应个输入的梯度输出,即grad_input。
权值的梯度一定要返回,返回顺序要与输入保持一致,即grad_input_n、···、grad_input_n,grad_weight和grad_bias。
梯度测试
from torch.autograd import gradcheck
input = (Variable(torch.randn(20,20).double(), requires_grad=True), Variable(torch.randn(30,20).double(), requires_grad=True),)
test = gradcheck(Linear.apply, input, eps=1e-6, atol=1e-4)
print(test)
添加Module
拓展torch.nn时,我们需要创建一个新的Module。
class Linear(nn.Module):
def __init__(self, input_features, output_features, bias=True):
super(Linear, self).__init__()
self.input_features = input_features
self.output_features = output_features
# nn.Parameter 是自动在Module中注册的变量,其它变量需要注册时使用.register_buffer()方法,否则会出问题。
self.weight = nn.Parameter(torch.Tensor(output_features, input_features))
if bias:
self.bias = nn.Parameter(torch.Tensor(output_features))
else:
# 除了可选参数外,所有参数都需要注册。
self.register_parameter('bias', None)
# 简单的参数初始化方法
self.weight.data.uniform_(-0.1, 0.1)
if bias is not None:
self.bias.data.uniform_(-0.1, 0.1)
def forward(self, input):
# 调用实现函数
return LinearFunction.apply(input, self.weight, self.bias)