第一张图为（对数）幅度谱，即 log(1 + |F|)，第二张图为相位谱，第三张图为反变换（重建）结果。
全部代码如下：
import torch
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
def fft(image, shift=True, norm='backward', rfft=False):
    r'''Return the magnitude and phase spectra of a batch of images.

    image: tensor of shape [b, c, h, w] (any [..., h, w] works); the 2-D
        transform is taken over the last two (spatial) dimensions only.
        Must be real-valued when rfft is True.
    shift: if True (and rfft is False), move the zero-frequency component to
        the centre of each h x w spectrum. Batch/channel dims are never shifted.
    norm: "backward", "forward", "ortho" or None (same as "backward"); passed
        straight to torch.fft.
    rfft: if True, use torch.fft.rfft2, whose last dimension has w // 2 + 1
        entries; ``shift`` is ignored in this mode.
    return: (magnitude, phase), each [b, c, h, w]
        ([b, c, h, w // 2 + 1] when rfft is True).
    '''
    if rfft:
        spectrum = torch.fft.rfft2(image, norm=norm)
    else:
        spectrum = torch.fft.fft2(image, norm=norm)
        if shift:
            # Shift only the spatial axes: the default (dim=None) also rolls
            # the batch and channel axes, which silently permutes channels.
            spectrum = torch.fft.fftshift(spectrum, dim=(-2, -1))
    return torch.abs(spectrum), torch.angle(spectrum)


def ifft(magnitude, phase, shift=True, norm='backward', rfft=False, s=None):
    r'''Rebuild spatial-domain images from magnitude and phase spectra.

    Inverse of ``fft``: call it with the same ``shift``, ``norm`` and ``rfft``
    settings that produced the spectra.
    magnitude, phase: real tensors of shape [..., h, w] as returned by fft.
    shift: undo fft's centring of the spatial axes (ignored when rfft is True).
    norm: "backward", "forward", "ortho" or None; passed to torch.fft.
    rfft: if True, the spectra came from rfft2 and irfft2 is used.
    s: optional (h, w) output size forwarded to the inverse transform. With
        rfft=True it is required to recover an odd original width, because
        irfft2 otherwise produces a width of 2 * (w_spectrum - 1).
    return: real tensor [b, c, h, w].
    '''
    if shift and not rfft:
        # Undo fft's shift on the spatial axes only.
        magnitude = torch.fft.ifftshift(magnitude, dim=(-2, -1))
        phase = torch.fft.ifftshift(phase, dim=(-2, -1))
    # magnitude * exp(i * phase), without any additive epsilon bias.
    spectrum = torch.polar(magnitude, phase)
    if rfft:
        # irfft2 already returns a real tensor.
        return torch.fft.irfft2(spectrum, s=s, norm=norm)
    # For spectra of real images the imaginary part is only rounding noise.
    return torch.fft.ifft2(spectrum, s=s, norm=norm).real
if __name__ == '__main__':
    # Read the image and close the file handle once the pixels are copied.
    with Image.open('RGB_tif_R47C9.tif') as img:
        # convert('RGB') guarantees a 3-channel 8-bit array (expands greyscale,
        # drops alpha) so the permute below always sees [h, w, 3].
        pixels = np.array(img.convert('RGB'))
    # [h, w, 3] uint8 -> [1, 3, h, w] float in [0, 1]
    image = torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0) / 255
    mag_image, pha_image = fft(image, shift=True)
    image2 = ifft(mag_image, pha_image, shift=True)

    fig, axes = plt.subplots(1, 3)
    # Channel 0 of the returned spectra; log1p compresses the dominant DC
    # term so the rest of the magnitude spectrum stays visible.
    axes[0].imshow(torch.log1p(mag_image[0, 0]).numpy())
    axes[0].set_title('log magnitude')
    axes[1].imshow(pha_image[0, 0].numpy())
    axes[1].set_title('phase')
    axes[2].imshow(image2[0].permute(1, 2, 0).clamp(0., 1.).numpy())
    axes[2].set_title('reconstruction')
    plt.show()