# References (straight-through gradient estimator for rounding):
# https://blog.csdn.net/winycg/article/details/104410525
# https://tissue333.gitbook.io/cornell/findings/pytorch_backward
import torch
class RoundNoGradient(torch.autograd.Function):
    """Round to the nearest integer, passing gradients straight through.

    ``round()`` has a zero derivative almost everywhere, so a plain call
    would stop gradient flow. This Function implements the straight-through
    estimator: forward rounds, backward returns the incoming gradient
    unchanged. Use via ``RoundNoGradient.apply(x)``.
    """

    @staticmethod
    def forward(ctx, x):
        # ctx is required by the torch.autograd.Function protocol; nothing
        # needs to be saved since backward ignores the forward inputs.
        return x.round()

    @staticmethod
    def backward(ctx, g):
        # Identity backward: pretend round() is the identity function.
        return g
# Demo: gradients flow through the rounding op back to the leaf tensor.
m = torch.nn.Conv2d(16, 10, 3, stride=2)
l = torch.nn.L1Loss()
# torch.autograd.Variable is deprecated; a tensor with requires_grad=True
# is the modern equivalent.
input = torch.randn(20, 16, 50, 10, requires_grad=True)
# Custom Functions must be invoked through .apply(); the legacy
# RoundNoGradient()(input) call style raises a RuntimeError in modern PyTorch.
x = RoundNoGradient.apply(input)
y = m(x)  # NOTE(review): y is computed but unused by the loss below — presumably illustrative.
output = l(x, torch.randn(20, 16, 50, 10))
output.backward()
# x is a non-leaf tensor, so x.grad is None; the gradient accumulates on
# the leaf `input` (all ones here, since L1Loss' sign passes through round).
print(input.grad)