import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt
# U-Net-style encoder/decoder used as the denoising model.
class UNet(nn.Module):
    """Small convolutional encoder/decoder for image denoising.

    The network downsamples twice (two ``MaxPool2d(2)`` stages) and upsamples
    once with a stride-2 transposed convolution, so the decoder produces a map
    at roughly half the input resolution. ``forward`` then bilinearly resizes
    it back to the spatial size of the input.

    NOTE(review): despite the name there are no skip connections between the
    encoder and the decoder, so this is a plain encoder/decoder rather than a
    true U-Net. Fine detail is recovered only through the final bilinear
    resize. Adding skips or a second upsampling stage would change the
    ``state_dict`` layout, so that is left as a TODO.

    Inputs must be at least 4x4 pixels; smaller inputs are pooled down to an
    empty feature map.

    Args:
        in_channels: Number of channels of the input image (1 = grayscale).
        out_channels: Number of channels of the predicted image.
    """

    def __init__(self, in_channels: int = 1, out_channels: int = 1):
        super().__init__()
        # Encoder: two 3x3 convs at full resolution, then halve H and W.
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
        )
        # Bottleneck: widen to 128 channels, then halve H and W again.
        self.middle = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
        )
        # Decoder: a single 2x upsampling step, then project to out_channels.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, out_channels, kernel_size=3, padding=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Denoise a batch of images.

        Args:
            x: Input batch of shape (N, in_channels, H, W) with H, W >= 4.

        Returns:
            Tensor of shape (N, out_channels, H, W), i.e. the same spatial
            size as ``x``.
        """
        x1 = self.encoder(x)
        x2 = self.middle(x1)
        x3 = self.decoder(x2)
        # Resize back to the *input's* spatial size rather than a hard-coded
        # (28, 28), so the model also works on non-MNIST image sizes.
        return F.interpolate(
            x3, size=tuple(x.shape[-2:]), mode='bilinear', align_corners=False
        )
# Custom dataset: (noisy, clean) image pairs built from a folder of JPEGs.
class MNISTDataset(Dataset):
    """Dataset yielding ``(noisy_image, clean_image)`` pairs for denoising.

    Each matching file in ``images_folder`` is loaded as grayscale and scaled
    to [0, 1]. It is returned as a float32 tensor of shape (1, H, W), together
    with a copy corrupted by additive Gaussian noise and clamped to [0, 1].
    The noise is re-drawn on every access, so the same index yields a
    different noisy input each time.

    Args:
        images_folder: Directory containing the images (not searched
            recursively).
        transform: Optional callable applied to the clean (1, H, W) tensor
            *before* noise is added, so the noisy/clean pair stays aligned.
            ``None`` means no extra transform.
        noise_std: Standard deviation of the additive Gaussian noise.
        extensions: File suffixes to include, matched case-insensitively.
    """

    def __init__(self, images_folder, transform=None, noise_std=0.2,
                 extensions=('.jpg', '.jpeg')):
        self.images_folder = images_folder
        self.transform = transform
        self.noise_std = noise_std
        suffixes = tuple(ext.lower() for ext in extensions)
        # os.listdir returns entries in arbitrary order; sort so that indices
        # (e.g. dataset[0] used for the demo) are stable across runs.
        self.image_files = sorted(
            f for f in os.listdir(images_folder) if f.lower().endswith(suffixes)
        )

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(noisy_img, img)``, both float32 tensors of shape (1, H, W)."""
        img_path = os.path.join(self.images_folder, self.image_files[idx])
        # Context manager releases the file handle deterministically;
        # convert('L') loads the pixel data before the file is closed.
        with Image.open(img_path) as pil_img:
            gray = pil_img.convert('L')
        arr = np.asarray(gray, dtype=np.float32) / 255.0  # scale to [0, 1]
        img = torch.from_numpy(arr).unsqueeze(0)  # (H, W) -> (1, H, W)
        if self.transform is not None:
            img = self.transform(img)
        noisy_img = self.add_noise(img)
        return noisy_img, img

    def add_noise(self, img):
        """Add Gaussian noise with std ``self.noise_std`` and clamp to [0, 1]."""
        noise = torch.randn_like(img) * self.noise_std
        return torch.clamp(img + noise, 0.0, 1.0)
# Data preprocessing and loading.
transform = None  # passed to MNISTDataset; None = no extra transforms
# Folder containing the JPEG images. Set the MNIST_JPG_DIR environment
# variable to point elsewhere instead of editing this machine-specific path.
images_folder = os.environ.get(
    'MNIST_JPG_DIR',
    'D:/APP/QQ/download/相关代码/图片修复/mnist—train/mnist_jpg',
)
# Build the dataset and a shuffled loader of (noisy, clean) batches.
dataset = MNISTDataset(images_folder, transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
# Initialise the model, loss and optimiser. Use the GPU when available and
# fall back to the CPU instead of crashing in an unconditional .cuda().
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet().to(device)
criterion = nn.MSELoss()  # pixel-wise mean squared error vs. the clean image
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Training loop: each batch is (noisy input, clean target).
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for noisy_images, clean_images in dataloader:
        noisy_images = noisy_images.to(device)
        clean_images = clean_images.to(device)

        # Forward pass and loss against the clean target.
        outputs = model(noisy_images)
        loss = criterion(outputs, clean_images)

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # criterion returns the per-batch mean; weight it by batch size so a
        # smaller final batch does not skew the epoch average.
        running_loss += loss.item() * noisy_images.size(0)
    epoch_loss = running_loss / len(dataloader.dataset)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}")
# Evaluation: denoise one sample and show clean / noisy / denoised images.
model.eval()
# Run on whatever device the model's parameters live on, so this section
# works whether training used CUDA or the CPU.
eval_device = next(model.parameters()).device
with torch.no_grad():
    noisy_image, clean_image = dataset[0]  # fetch one (noisy, clean) sample
    noisy_image = noisy_image.to(eval_device)
    clean_image = clean_image.to(eval_device)
    # Add a batch dimension of size 1 for the forward pass.
    denoised_image = model(noisy_image.unsqueeze(0))

panels = [
    (clean_image, 'Clean Image'),
    (noisy_image, 'Noisy Image'),
    (denoised_image, 'Denoised Image'),
]
plt.figure(figsize=(12, 4))
for position, (image, title) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    # Pin the grey-level range to [0, 1] so all three panels share one scale;
    # by default imshow rescales each image to its own min/max.
    plt.imshow(image.cpu().squeeze(), cmap='gray', vmin=0.0, vmax=1.0)
    plt.title(title)
    plt.axis('off')
plt.show()
根据以上代码，请说明这个基于 U-Net 的老照片修复项目中，模型的训练计划是什么？