原因分析
源代码是用的torch1.7以前版本写的,自己环境是torch1.7以上,原来torch.rfft 和torch.rfft 被torch 升级删掉了,调用不了。
解决方法
不用重新安装环境,降torch版本!!!!!
改函数接口即可
搜索使用torch.rfft 和torch.irfft 的地方
在该文件最前面加上以下代码。同时,将torch.rfft 改为 rfft ,将torch.irfft 改为irfft
try:
from torch import irfft
from torch import rfft
except ImportError:
from torch.fft import irfft2
from torch.fft import rfft2
def rfft(x, d):
t = rfft2(x, dim = (-d,-1))
return torch.stack((t.real, t.imag), -1)
def irfft(x, d, signal_sizes):
return irfft2(torch.complex(x[:,:,0], x[:,:,1]), s = signal_sizes, dim = (-d,-1))