今天自定义不可导函数的时候遇到了一个大坑。
首先我需要自定义一个函数:sign_f
import torch
from torch.autograd import Function
import torch.nn as nn
class sign_f(Function):
@staticmethod
def forward(ctx, inputs):
output = inputs.new(inputs.size())
output[inputs >= 0.] = 1
output[inputs < 0.] = -1
ctx.save_for_backward(inputs)
return output
@staticmethod
def backward(ctx, grad_output):
input_, = ctx.saved_tensors
grad_output[input_>1.] = 0
grad_output[input_<-1.] = 0
return grad_output
然后我需要把它封装为一个module 类型,就像 nn.Conv2d 模块 封装 f.conv2d 一样,于是
import torch
from torch.autograd import Function
import torch.nn as nn
class sign_(nn.Module):
# 我需要的module
def __init__(self, *kargs, **kwargs):
super(sign_, self).__init__(*kargs, **kwargs)
def forward(self, inputs):
# 使用自定义函数
outs = sign_f(inputs)
return outs
class sign_f(Function):
@staticmethod
def forward(ctx, inputs):
output = inputs.new(inputs.size())
output[inputs >= 0.] = 1
output[inputs < 0.] = -1
ctx.save_for_backward(inputs)
return output
@staticmethod
def backward(ctx, grad_output):
input_, = ctx.saved_tensors
grad_output[input_>1.] = 0
grad_output[input_<-1.] = 0
return grad_output
结果报错
TypeError: backward() missing 2 required positional arguments: 'ctx' and 'grad_output'
我试了半天,发现自定义函数后面要加 apply ,详细见下面
import torch
from torch.autograd import Function
import torch.nn as nn
class sign_(nn.Module):
def __init__(self, *kargs, **kwargs):
super(sign_, self).__init__(*kargs, **kwargs)
self.r = sign_f.apply ### <-----注意此处
def forward(self, inputs):
outs = self.r(inputs)
return outs
class sign_f(Function):
@staticmethod
def forward(ctx, inputs):
output = inputs.new(inputs.size())
output[inputs >= 0.] = 1
output[inputs < 0.] = -1
ctx.save_for_backward(inputs)
return output
@staticmethod
def backward(ctx, grad_output):
input_, = ctx.saved_tensors
grad_output[input_>1.] = 0
grad_output[input_<-1.] = 0
return grad_output
问题解决了