旧版中 pytorch.rfft 函数与新版 pytorch.fft.rfft 函数对应修改问题

旧版中 pytorch.rfft 函数与新版 pytorch.fft.rfft 函数对应修改问题


前言

这两天整理谱池化操作,需要用到傅里叶变换这个函数。后来提升了pytorch的版本以后,发现之前的torch.rfft() 函数在新版的pytorch中使用会报错,后来查阅资料,发现是新版的参数有些变动。

pytorch旧版本(1.7之前)中有一个函数torch.rfft(),但是新版本(1.8、1.9)中被移除了,添加了torch.fft.rfft(),但它跟旧版的函数有了很大的变动,参数进行了一个大的调整。
傅里叶变换的整个过程我并没有搞的十分清晰,尤其是pytorch中的引用,网上对于这个函数解析的资料也十分有限,然后从知乎上参考了一篇文章,将我的问题解决了,感谢这位仁兄。


一、旧版 pytorch.rfft()函数解释

fft = torch.rfft(input, 2, normalized=True, onesided=False)
#  input 为输入的图片或者向量,dtype=torch.float32,size比如为[1,3,64,64]

参数说明:

input (Tensor) – the input tensor of at least signal_ndim dimensions
signal_ndim (int) – the number of dimensions in each signal. signal_ndim can only be 1, 2 or 3
normalized (bool, optional) – controls whether to return normalized results. Default: False
onesided (bool, optional) – controls whether to return half of results to avoid redundancy. Default: True
在上述的代码中,signal_ndim=2 因为图像是二维的,normalized=False 说明不进行归一化,onesided=False 则是希望不要减少最后一个维度的大小

在1.7版本torch.rfft中,有一个warning,表示在新版中,要“one-side ouput”的话用torch.fft.rfft(),要“two-side ouput”的话用torch.fft.fft()。这里的one/two side,跟旧版的onesided参数对应,所以我们要的是新版的torch.fft.fft()

需要注意的是,假设输入tensor的维度为 [ N 1 , N 2 , , , , N d ] [N_1,N_2,,,,N_d] [N1,N2,,,,Nd],则输出tensor的维度为 [ N 1 , N 2 , , , , N d , 2 ] [N_1,N_2,,,,N_d,2] [N1,N2,,,,Nd,2] 。最后一个维度2表示复数中的实部、虚部,即 z = a + b i z =a+bi z=a+bi这样的复数,在旧版pytorch中表示为一个二维向量 [ a , b ] [a,b] [a,b]

二、新版pytorch.fft.rfft()函数解释

新版官网解释

Getting started with the new torch.fft module is easy whether you are familiar with NumPy’s np.fft module or not. While complete documentation for each function in the module can be found here, a breakdown of what it offers is:

  • fft, which computes a complex FFT over a single dimension, and ifft, its inverse
  • the more general fftn and ifftn, which support multiple dimensions
  • The “real” FFT functions, rfft, irfft, rfftn, irfftn, designed to work with signals that are real-valued in their time domains
  • The “Hermitian” FFT functions, hfft and ihfft, designed to work with signals that are real-valued in their frequency domains
  • Helper functions, like fftfreq, rfftfreq, fftshift, ifftshift, that make it easier to manipulate signals

官网解释链接:https://pytorch.org/blog/the-torch.fft-module-accelerated-fast-fourier-transforms-with-autograd-in-pyTorch/

小结:可以看到这里也有rfft,官方文档说是用来处理都是实数的输入。但是它在前面的warning中说了是one-side,而我们要的是two-side。此外实数也可以看作是虚部都为0的复数,所以用fft没问题。
新版的rfft和fft都是用于一维输入,而我们的图像是二维,所以应该用rfft2和fft2。在fft2中,参数dim用来指定用于傅里叶变换的维度,默认(-2,-1),正好对应H、W两个维度。
新版所有的fft都不将复数 z = a + b j z=a+bj z=a+bj 存成二维向量了,而是一个数 [ z = a + b j ] [z=a+bj] [z=a+bj]。所以如果要跟旧版中一样存成二维向量,需要用.real()和.imag()提取复数的实部和虚部,然后用torch.stack()堆到一起,即可。


三、总结

代码变更对比如下:

import torch
input = torch.rand(1,3,32,32)

# 旧版pytorch.rfft()函数
fft = torch.rfft(input, 2, normalized=True, onesided=False)

# 新版 pytorch.fft.rfft2()函数
output = torch.fft.fft2(input, dim=(-2, -1))
output = torch.stack((output.real, output_new.imag), -1)

以上是我的理解,整体理解参考文章如下连接。

知乎:旧版pytorch中torch.rfft在新版本中的对应

  • 10
    点赞
  • 35
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值