import os
import torch
import torch.nn as nn
import torch.optim as optim
from matplotlib import pyplot as plt
from torchvision import datasets, transforms
from PIL import Image
# 检查是否安装了CUDA,并且CUDA是否适用于你的GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 超参数设置
batch_size = 16
learning_rate = 1e-3
num_epochs = 1
# 加载MNIST手写数字数据集
train_dataset = datasets.MNIST(root='./data/', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root='./data/', train=False, transform=transforms.ToTensor())
# 创建数据加载器
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)
#shuffle参数是用来指定在每个epoch开始时是否对数据进行洗牌(随机排序)。设置为True时,数据加载器会在每个epoch开始时重新随机排列数据;设置为False时,数据会按照其原始顺序加载。
# 定义一个自编码器的类
class Autoencoder(nn.Module):
def __init__(self):
super(Autoencoder, self).__init__()
# 编码器部分
self.encoder = nn.Sequential(
nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1), # 输入1通道,输出16通道,3x3卷积,步长为2,padding为1
nn.ReLU(),
nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1), # 输入16通道,输出32通道,3x3卷积,步长为2,padding为1
nn.ReLU(),
nn.Conv2d(32, 64, kernel_size=7), # 输入32通道,输出64通道,7x7卷积
)
# 解码器部分
self.decoder = nn.Sequential(
nn.ConvTranspose2d(64, 32, kernel_size=7), # 输入64通道,输出32通道,7x7卷积转置
nn.ReLU(),
nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),
# 输入32通道,输出16通道,3x3卷积转置,步长为2,padding为1,输出padding为1
nn.ReLU(),
nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1),
# 输入16通道,输出1通道,3x3卷积转置,步长为2,padding为1,输出padding为1
nn.Sigmoid()
)
def forward(self, x):
x = self.encoder(x)
x = self.decoder(x)
return x
# 创建自编码器对象并将模型移动到GPU
model = Autoencoder().to(device)
# 定义损失函数和优化器
criterion = nn.MSELoss()
criterion1 = nn.L1Loss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# 训练自编码器
for epoch in range(num_epochs):
for data in train_loader:
img, _ = data
#其中 data 是一个变量或者对象,img 是在这行代码中被赋值的一个变量。
# 通常情况下,_ 是一个占位符,表示在这里我们不需要使用这个变量的值。
# 因此,这行代码的作用是将 data 中的某些内容赋值给 img,而不需要使用 data 的其他部分。
img = img.to(device)
# 前向传播
output = model(img)
loss = criterion(output, img)+criterion1(output, img)
# 反向传播和优化器优化
optimizer.zero_grad()
#清空梯度的操作。在深度学习中,通常在每一轮参数更新之前调用该方法,以确保不同批次的梯度不会累积。
loss.backward()
#计算模型参数的梯度,以便进行优化更新。
optimizer.step()
#调用 optimizer.step() 就是告诉优化器执行一步参数更新,以减小损失函数的值。
print("Epoch[{}/{}], loss:{:.4f}".format(epoch+1, num_epochs, loss.data))
# 迭代测试数据集,生成迭代器
dataiter = iter(test_loader)
# 从迭代器中获取下一个批次的图像和标签
images, labels = next(dataiter)
# 使用模型进行推断,处理获取的图像数据,并将结果保存在output变量中
output = model(images.to(device))
# 创建子图和轴对象,其中第一行显示原始图像,第二行显示重构后的图像
fig, axes = plt.subplots(nrows=2, ncols=10, sharex=True, sharey=True, figsize=(25,4))
# 循环遍历前10个图像,绘制原始图像和重构图像并添加标题
for i in range(10):
# 显示原始图像
axes[0,i].imshow(images[i].squeeze().numpy(), cmap='gray')
axes[0,i].set_title("Original")
axes[0,i].get_xaxis().set_visible(False)
axes[0,i].get_yaxis().set_visible(False)
# 显示重构后的图像
axes[1,i].imshow(output[i].squeeze().cpu().detach().numpy(), cmap='gray')
axes[1,i].set_title("Reconstructed")
axes[1,i].get_xaxis().set_visible(False)
axes[1,i].get_yaxis().set_visible(False)
plt.savefig(os.path.join('results', f'zzimage_{i}.png'))
# 显示生成的子图
plt.show()
folder_path = "results"
os.makedirs(folder_path, exist_ok=True)
# Save each image in the folder
for i, output_image in enumerate(output):
plt.imshow(output_image.squeeze().cpu().detach().numpy(), cmap='gray')
plt.savefig(os.path.join(folder_path, f"image_{i}.png"))
plt.close()
folder_path = "results1/"
os.makedirs(folder_path, exist_ok=True)
# Save each image in the folder
for i, output_image in enumerate(output):
tensor_denormalized = output_image * 255
# 将张量四舍五入到最接近的整数
tensor_denormalized = torch.round(tensor_denormalized)
# 将张量转换为整数类型
tensor_denormalized = tensor_denormalized.type(torch.uint8)
output_image_np = tensor_denormalized.squeeze().detach().cpu().numpy() # Convert PyTorch tensor to numpy array
image = Image.fromarray(output_image_np) # Create PIL image from numpy array
image = image.convert("RGB")
image.save(folder_path+ f"image_{i}.png")
手写数字重构代码ED
最新推荐文章于 2024-07-25 13:23:14 发布