class Sigmoid(Function):
    """Element-wise logistic sigmoid implemented as a custom autograd Function.

    ``forward`` computes ``1 / (1 + exp(-x))`` and saves that *output*.
    ``backward`` reuses the saved output, because
    d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x)); no exponential has to be
    recomputed during the backward pass.

    Use via ``Sigmoid.apply(x)``.
    """

    @staticmethod
    def forward(ctx, x):
        """Return ``sigmoid(x)`` and stash it on ``ctx`` for ``backward``.

        Args:
            ctx: autograd context object supplied by PyTorch.
            x: input tensor.

        Returns:
            Tensor ``1 / (1 + exp(-x))``, same shape as ``x``.
        """
        output = 1 / (1 + t.exp(-x))
        # Save the output rather than the input: the derivative below is
        # expressed purely in terms of sigmoid(x).
        ctx.save_for_backward(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient w.r.t. ``x``.

        Args:
            ctx: autograd context holding the tensor saved in ``forward``.
            grad_output: upstream gradient dL/dy.

        Returns:
            ``grad_output * sigmoid(x) * (1 - sigmoid(x))``.
        """
        # ``saved_tensors`` is the supported accessor; ``saved_variables``
        # is deprecated.
        (output,) = ctx.saved_tensors
        return output * (1 - output) * grad_output
def f_sigmoid(x):
    """Run the custom ``Sigmoid`` forward and backward on ``x``.

    The gradient dL/dx (with dL/dy seeded to all ones) is written into
    ``x.grad`` by autograd, accumulating onto any existing value. Expects
    ``x`` to require grad; otherwise ``backward`` raises.

    Args:
        x: input tensor with ``requires_grad=True``.

    Returns:
        The forward output ``Sigmoid.apply(x)``.
    """
    y = Sigmoid.apply(x)
    # Seed with ones matching y's shape, dtype and device. ``t.ones(x.size())``
    # always produced a default-dtype CPU tensor, which mismatches inputs on
    # other devices (e.g. CUDA).
    y.backward(t.ones_like(y))
    return y
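# The backward pass triggered in `f_sigmoid` runs the custom `Sigmoid.backward`,
# which reuses the output saved in `forward` (sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x)))
# instead of recomputing the exponential, making the backward computation cheaper.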