Normally, PyTorch computes gradients automatically through its autograd mechanism.
However, if we define a loss function that involves operations autograd does not know how to differentiate, automatic differentiation no longer works.
Thus, adding such operations to autograd requires implementing a new Function subclass for each operation. Recall that Functions are what autograd uses to compute the results and gradients, and to encode the operation history.
This is done by following “Extending torch.autograd”:
(1) __init__ (optional) - If your operation takes non-Variable arguments, pass them into the operation as arguments to __init__. For example, an AddConstant Function takes the constant to add, and a Transpose Function needs to know which two dimensions to swap (see the sketch after this list). If your operation needs no extra arguments, you can omit __init__.
(2) forward() - computes the forward pass of the op
- Before forward is executed, Variable arguments have already been converted to Tensors
- forward's parameters may have default values, and a default value can be any Python object
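The items above describe the legacy Function API, in which non-Variable arguments are passed through __init__ and forward/backward are instance methods. In current PyTorch the same idea is written with static forward/backward methods and a ctx object, and constants are passed as extra arguments to forward. Below is a minimal sketch of the AddConstant example in that modern style (the class name and constant value are illustrative):

```python
import torch
from torch.autograd import Function

class AddConstant(Function):
    @staticmethod
    def forward(ctx, input, constant):
        # `constant` is a plain Python number rather than a tensor,
        # so nothing needs to be saved for the backward pass.
        return input + constant

    @staticmethod
    def backward(ctx, grad_output):
        # d(input + constant)/d(input) = 1, so the incoming gradient
        # passes through unchanged; non-tensor arguments get None.
        return grad_output, None

# Usage: custom Functions are invoked through .apply()
x = torch.randn(3, requires_grad=True)
y = AddConstant.apply(x, 5.0)
y.sum().backward()
print(x.grad)  # tensor([1., 1., 1.])
```

Note that backward must return one value per argument of forward, which is why the non-differentiable constant receives None.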