1. 傅里叶变换提取高频和低频【有损：逆变换后对低频、高频分量分别取模（torch.abs），丢失了符号信息，因此低频 + 高频 ≠ 原图】
- 环境:集群210.30.98.11
- 效果:

2. 傅里叶变换提取振幅和相位【无损】
- 环境:集群210.30.98.11
- 效果:

3. 小波变换【不涉及恢复代码】
- 环境:集群210.30.98.11
- 效果:

代码1.
import torchvision.transforms as T
from PIL import Image
import torch
import matplotlib.pyplot as plt
import os
def extract_frequency_components(image, cutoff_ratio=0, magnitude=True):
    """Split an image into low- and high-frequency parts with a centred box mask.

    The 2-D FFT over the last two (spatial) dimensions is shifted so the zero
    frequency sits at ``(H // 2, W // 2)``. A rectangle of half-size
    ``int(cutoff_ratio * H)`` by ``int(cutoff_ratio * W)`` around that centre
    is kept as the low-pass band; everything else is the high-pass band.

    Args:
        image: Tensor whose last two dims are ``(H, W)`` (e.g. ``(C, H, W)``);
            it is converted to float32.
        cutoff_ratio: Half-size of the low-pass box as a fraction of each
            spatial dimension. ``0`` gives an empty low-pass band; values
            ``>= 0.5`` cover the whole spectrum (the box is clamped to the
            image bounds).
        magnitude: If True (default, original behaviour) return the absolute
            value of each inverse transform. This discards sign, so
            ``low + high`` generally differs from ``image``. If False return
            the real parts, for which ``low + high`` reconstructs ``image`` up
            to floating-point error.

    Returns:
        ``(low_freq, high_freq)``: real tensors with the same shape as ``image``.
    """
    image = image.float()
    H, W = image.shape[-2:]
    spatial = (-2, -1)
    fft_image = torch.fft.fftshift(torch.fft.fft2(image, dim=spatial), dim=spatial)
    center_x, center_y = H // 2, W // 2
    cutoff_x, cutoff_y = int(cutoff_ratio * H), int(cutoff_ratio * W)
    # Clamp the box to the image: a negative start index would be read by
    # Python slicing as "counted from the end" and select only a sliver at
    # the far edge instead of the whole spectrum.
    x0, x1 = max(center_x - cutoff_x, 0), min(center_x + cutoff_x, H)
    y0, y1 = max(center_y - cutoff_y, 0), min(center_y + cutoff_y, W)
    mask = torch.zeros(H, W, dtype=image.dtype, device=image.device)
    mask[x0:x1, y0:y1] = 1
    # The (H, W) mask broadcasts over any leading (e.g. channel) dimensions.
    low_freq_fft = fft_image * mask
    high_freq_fft = fft_image * (1 - mask)

    def _to_spatial(spectrum):
        # Undo the centring shift, invert the FFT, then take |.| or the real part.
        values = torch.fft.ifft2(torch.fft.ifftshift(spectrum, dim=spatial), dim=spatial)
        return torch.abs(values) if magnitude else values.real

    return _to_spatial(low_freq_fft), _to_spatial(high_freq_fft)
def recover_image(low_freq, high_freq):
    """Recombine the two frequency bands by element-wise addition.

    Returns ``low_freq + high_freq``.
    """
    combined = low_freq + high_freq
    return combined
def save_image(tensor_image, output_path):
    """Clamp ``tensor_image`` to [0, 1] and write it to ``output_path``.

    The clamped tensor is converted with ``T.ToPILImage`` and saved through
    PIL's ``Image.save``.
    """
    clamped = tensor_image.clamp(0, 1)
    T.ToPILImage()(clamped).save(output_path)
def process_and_save_images(image_path, output_dir, cutoff_ratio=0.1):
    """Split an image into low/high-frequency bands and save a 2x2 figure.

    The image is loaded as RGB, resized to 256x256, decomposed with
    ``extract_frequency_components`` and recombined with ``recover_image``.
    The original, low-frequency, high-frequency and recovered images are
    drawn into one figure saved as ``combined_images2.png`` inside
    ``output_dir`` (created if missing).

    Args:
        image_path: Path of the input image.
        output_dir: Directory that receives the combined figure.
        cutoff_ratio: Forwarded to ``extract_frequency_components``.
    """
    os.makedirs(output_dir, exist_ok=True)
    # Close the underlying file once the converted copy exists.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    transform = T.Compose([
        T.Resize((256, 256)),
        T.ToTensor()
    ])
    image_tensor = transform(image)
    low_freq, high_freq = extract_frequency_components(image_tensor, cutoff_ratio)
    recovered_image = recover_image(low_freq, high_freq)
    panels = [
        (image_tensor, 'Original Image'),
        (low_freq, 'Low Frequency Image'),
        (high_freq, 'High Frequency Image'),
        (recovered_image, 'Recovered Image'),
    ]
    fig, axes = plt.subplots(2, 2, figsize=(8, 8))
    try:
        # axes.flat walks row-major: (0,0), (0,1), (1,0), (1,1).
        for ax, (tensor, title) in zip(axes.flat, panels):
            # The bands / recovered image may fall outside [0, 1]; imshow
            # would clip float RGB with a warning, so clamp explicitly.
            # permute converts CHW to the HWC layout imshow expects.
            ax.imshow(tensor.clamp(0, 1).permute(1, 2, 0).numpy())
            ax.set_title(title)
            ax.axis('off')
        plt.tight_layout()
        combined_image_path = os.path.join(output_dir, "combined_images2.png")
        plt.savefig(combined_image_path)
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)
    print(f"Processed images saved to {output_dir}")
# Demo run: decompose lena.png with a 10% low-pass box.
image_path, output_dir = "lena.png", "output_images"
process_and_save_images(image_path, output_dir, cutoff_ratio=0.1)
代码2.
import torchvision.transforms as T
from PIL import Image
import torch
import matplotlib.pyplot as plt
import os
def extract_frequency_components(image):
    """Return the centred amplitude and phase spectra of ``image``.

    A 2-D FFT is taken over the last two (spatial) dimensions and shifted so
    the zero frequency sits in the middle of each channel.

    Args:
        image: Real tensor whose last two dims are ``(H, W)``, e.g. ``(C, H, W)``.

    Returns:
        ``(magnitude, phase)``: real tensors with the same shape as ``image``;
        phase is in radians as returned by ``torch.angle``.
    """
    spatial = (-2, -1)
    fft_image = torch.fft.fftn(image, dim=spatial)
    # Shift only the spatial axes. Without ``dim`` fftshift also rolls the
    # channel axis, which permutes the channels of the returned spectra.
    fft_shifted = torch.fft.fftshift(fft_image, dim=spatial)
    magnitude = torch.abs(fft_shifted)
    phase = torch.angle(fft_shifted)
    return magnitude, phase


def recover_image(magnitude, phase):
    """Rebuild an image from the spectra produced by ``extract_frequency_components``.

    ``magnitude * (cos(phase) + i*sin(phase))`` restores the complex
    spectrum, so the round trip is lossless up to floating-point error.

    Returns:
        The real part of the inverse 2-D FFT, same shape as ``magnitude``.
    """
    spatial = (-2, -1)
    real = magnitude * torch.cos(phase)
    imag = magnitude * torch.sin(phase)
    complex_freq = torch.complex(real, imag)
    # Must undo exactly the spatial-only shift applied on extraction.
    complex_freq_shifted = torch.fft.ifftshift(complex_freq, dim=spatial)
    recovered_image = torch.fft.ifftn(complex_freq_shifted, dim=spatial)
    return recovered_image.real
def save_image(tensor_image, output_path):
    """Clamp ``tensor_image`` to [0, 1] and save it to ``output_path`` via PIL."""
    pil_image = T.ToPILImage()(tensor_image.clamp(0, 1))
    pil_image.save(output_path)
def process_and_save_images(image_path, output_dir, cutoff_ratio=0.1):
    """Show an image's amplitude and phase spectra plus its reconstruction.

    The image is loaded as RGB, resized to 256x256 and decomposed with
    ``extract_frequency_components``; ``recover_image`` rebuilds it from the
    two spectra. A 2x2 figure (original, log-amplitude, phase, recovered) is
    saved as ``combined_images3.png`` inside ``output_dir`` (created if
    missing).

    Args:
        image_path: Path of the input image.
        output_dir: Directory that receives the combined figure.
        cutoff_ratio: Unused here; kept for call compatibility.
    """
    os.makedirs(output_dir, exist_ok=True)
    # Close the underlying file once the converted copy exists.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    transform = T.Compose([
        T.Resize((256, 256)),
        T.ToTensor()
    ])
    image_tensor = transform(image)
    magnitude, phase = extract_frequency_components(image_tensor)
    recovered_image = recover_image(magnitude, phase)

    def _normalize(t):
        # Linearly rescale to [0, 1] so imshow does not clip the raw values.
        t = t - t.min()
        peak = t.max()
        return t / peak if peak > 0 else t

    panels = [
        (image_tensor, 'Original Image'),
        # Raw amplitudes span several orders of magnitude (the centre DC term
        # is the whole channel sum), so display log(1 + |F|) instead.
        (_normalize(torch.log1p(magnitude)), 'Magnitude Spectrum (log)'),
        (_normalize(phase), 'Phase Spectrum'),
        (recovered_image.clamp(0, 1), 'Recovered Image'),
    ]
    fig, axes = plt.subplots(2, 2, figsize=(8, 8))
    try:
        # axes.flat walks row-major: (0,0), (0,1), (1,0), (1,1).
        for ax, (tensor, title) in zip(axes.flat, panels):
            # permute converts CHW to the HWC layout imshow expects.
            ax.imshow(tensor.permute(1, 2, 0).numpy())
            ax.set_title(title)
            ax.axis('off')
        plt.tight_layout()
        combined_image_path = os.path.join(output_dir, "combined_images3.png")
        plt.savefig(combined_image_path)
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)
    print(f"Processed images saved to {output_dir}")
# Demo run on lena.png; cutoff_ratio is accepted but unused by this variant.
image_path = "lena.png"
output_dir = "output_images"
process_and_save_images(image_path, output_dir, cutoff_ratio=0.1)
代码3.
import torch
import torchvision.transforms as T
import pywt
import matplotlib.pyplot as plt
import os
from PIL import Image
# Single-level 2-D discrete wavelet transform ('db4') of every RGB channel of
# lena.png. Each channel's four sub-bands are plotted in one row of a grid
# and saved to wavelet_images/wavelet_decomposition_color.png.
output_folder = "wavelet_images"
os.makedirs(output_folder, exist_ok=True)
# Close the underlying file once the converted copy exists.
with Image.open('lena.png') as img:
    image = img.convert('RGB')
transform = T.Compose([T.ToTensor()])
image_tensor = transform(image)  # ToTensor already scales pixels to [0, 1]
num_channels = image_tensor.shape[0]
wavelet = 'db4'
num_rows = num_channels
num_cols = 4
print(f"num_rows: {num_rows}, num_cols: {num_cols}")
# squeeze=False keeps `axes` 2-D even for a single row, so indexing needs
# no special case.
fig, axes = plt.subplots(num_rows, num_cols, figsize=(16, num_rows * 4), squeeze=False)
print(f"axes shape: {axes.shape}")
try:
    for channel in range(num_channels):
        channel_image = image_tensor[channel, :, :].numpy()
        cA, (cH, cV, cD) = pywt.dwt2(channel_image, wavelet)
        # No extra rescaling: pixels are already in [0, 1], and imshow with a
        # colormap scales each band to its own min/max by default.
        # Sub-band labels follow the original notes' naming.
        subbands = [(cA, "LL"), (cH, "HL"), (cV, "LH"), (cD, "HH")]
        for col, (coeffs, label) in enumerate(subbands):
            ax = axes[channel, col]
            ax.imshow(coeffs, cmap='gray')
            ax.set_title(f"Channel {channel} - ({label})")
            ax.axis('off')
    plt.tight_layout()
    output_path = os.path.join(output_folder, "wavelet_decomposition_color.png")
    plt.savefig(output_path)
finally:
    # Release the figure even if drawing or saving fails.
    plt.close(fig)