总说
虽然pytorch可以自动求导,但是有时候一些操作是不可导的,这时候你需要自定义求导方式。也就是所谓的 “Extending torch.autograd”. 官网虽然给了例子,但是很简单。这里将会更好的说明。
#扩展 torch.autograd
class LinearFunction(Function):
# 必须是staticmethod
@staticmethod
# 第一个是ctx,第二个是input,其他是可选参数。
# ctx在这里类似self,ctx的属性可以在backward中调用。
def forward(ctx, input, weight, bias=None):
ctx.save_for_backward(input, weight, bias)
output = input.mm(weight.t())
if bias is not None:
output += bias.unsqueeze(0).expand_as(output)
return output
@staticmethod
def backward(ctx, grad_output):
input, weight, bias = ctx.</