报错：网络训练过程中出现 NaN 值。
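原因：相位函数 atan2(imag, real) 在复数模为零处不可导；其梯度含有 1/|x|² 项，当 |x|² 为零（或下溢为零）时，反向传播会产生 NaN。解决方法：自定义 backward，把分母 |x|² 截断到不小于一个小值（1e-10），避免反向传播出现 NaN。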
import torch
from torch import Tensor
from torch.autograd import Function
class angle(Function):
    """Like ``torch.angle`` for complex input, but with a NaN-safe gradient.

    Use as ``angle.apply(x)`` with a complex tensor ``x``.

    Forward returns ``atan2(x.imag, x.real)``. For ``x = a + ib`` the phase
    has partial derivatives ``-b / |x|^2`` (w.r.t. ``a``) and ``a / |x|^2``
    (w.r.t. ``b``). These are undefined at ``|x| == 0`` and overflow when
    ``|x|^2`` underflows. The backward pass therefore clamps ``|x|^2`` to at
    least ``1e-10``, which keeps the gradient finite and makes it zero at the
    origin.
    """

    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        # Save the complex input: backward needs its real and imaginary parts.
        ctx.save_for_backward(x)
        return torch.atan2(x.imag, x.real)

    @staticmethod
    def backward(ctx, grad: Tensor) -> Tensor:
        (x,) = ctx.saved_tensors
        re, im = x.real, x.imag
        # Lower-bound |x|^2 so the division cannot yield inf/NaN. The in-place
        # clamp is safe because the sum is a fresh temporary, not saved state.
        grad_inv = grad / (re.square() + im.square()).clamp_min_(1e-10)
        # PyTorch represents the gradient of a real output w.r.t. a complex
        # input as dL/da + i * dL/db; build that directly with torch.complex.
        return torch.complex(-im * grad_inv, re * grad_inv)
ref: GitHub - Rikorose/DeepFilterNet: Noise suppression using deep filtering (https://github.com/Rikorose/DeepFilterNet)