高版本torch(1.7及以后)中移除了torch.rfft与torch.irfft,使得高版本环境无法运行低版本的代码。以下是解决办法:
torch.rfft(t, 2, onesided=False)
改为本地实现代码:
def rfft(t):
# Real-to-complex Discrete Fourier Transform
x = torch.fft.fft2(t, dim=(-2, -1))
return torch.stack((x.real, x.imag), -1)
torch.irfft(t, 2, onesided=False)
替换为:
def irfft(t):
# Complex-to-real Inverse Discrete Fourier Transform
return torch.fft.ifft2(torch.complex(t[..., 0], t[..., 1])).real