视频地址:
https://www.bilibili.com/video/BV1Fp4y1o7Kw?t=225.9
# 自编码器学习代码,视频地址:https://www.bilibili.com/video/BV1Fp4y1o7Kw?t=225.9
# 原视频地址:https://www.youtube.com/watch?v=zp8clK9yCro
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets,transforms
import matplotlib.pyplot as plt
mnist_data = datasets.MNIST(root='./data_mnist', train=True, transform=transforms.ToTensor(),download=True)
data_loader = torch.utils.data.DataLoader(dataset=mnist_data,
batch_size=64,
shuffle=True)
dataiter = iter(data_loader) # iter()Python中的一个内置函数,该对象可以用于迭代可迭代对象(如列表、元组、字典等)。迭代器对象允许我们逐个访问可迭代对象中的元素,而不必一次性加载整个可迭代对象到内存中
images, labels = dataiter.next() # next() 函数逐个访问迭代器中的元素
print(torch.min(images), torch.max(images))
class Autoencoder(nn.Module):
def __init__(self):
super(Autoencoder, self).__init__()
self.encoder = nn.Sequential(
nn.Linear(28*28, 128), # 图片像素大小28*28
nn.ReLU(),
nn.Linear(128,64),
nn.ReLU(),
nn.Linear(64,12),
nn.ReLU(),
nn.Linear(12,3)
)
self.decoder = nn.Sequential(
nn.Linear(3, 12),
nn.ReLU(),
nn.Linear(12, 64),
nn.ReLU(),
nn.Linear(64, 128),
nn.ReLU(),
nn.Linear(128, 28*28),
nn.Sigmoid()
)
def forward(self,x): # 前向传播
encoded = self.encoder(x)
decoded = self.decoder(encoded)
return decoded
model = Autoencoder() # 实例化model
criterion = nn.MSELoss() # 交叉熵损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5) # 优化器
num_epochs = 10 # 定义训练多少个epoch
outputs = [] # outputs一个空列表
for epoch in range(num_epochs):
for(img,_) in data_loader: # _ 表示一个占位符,表示我们在这里不需要使用这个变量的值。在这个特定的情境中,data_loader 返回的每个批次的数据是一个元组 (img, label),其中 img 是图像数据,label 是对应的标签。然而,在这个循环中,我们似乎只关心图像数据而不关心标签,因此使用下划线 _ 表示我们暂时不需要使用这个变量的值。这样做的目的是为了提高代码的可读性,告诉阅读代码的人,我们暂时不关心标签这个变量的值
img = img.reshape(-1,28*28) # 展平操作,全连接层需要展平之后输入。将某一个维度设为-1,让PyTorch根据张量中元素的总数量和其他维度的大小,自动计算出该维度的大小,从而保证张量中所有元素的数量不变
recon = model(img)
loss = criterion(recon,img)
optimizer.zero_grad() #
loss.backward()
optimizer.step()
print(f'Epoch:{epoch+1}, Loss:{loss.item():.4f}')
outputs.append((epoch, img, recon)) # 在Python中,append() 方法接受一个参数,这个参数通常是要添加到列表中的元素。如果要添加的元素是一个元组,那么这个元组本身就是一个元素,因此需要将它放在括号中以表示一个整体
for k in range(0, num_epochs, 4): # 从0到num_epochs-1,步长为4。range用于for循环左闭右开相当于[0,epochs)
plt.figure(figsize=(9, 2)) # 创建一个大小为9x2英寸的新图形。
plt.gray() # 将图形的色彩模式设置为灰度
imgs = outputs[k][1].detach().numpy() # 从outputs列表中获取第k个元素(对应于第k个epoch)的原始图片数据和重建图片数据。这里的outputs[k]返回一个元组,第1个元素是原始图片,第2个元素是重建图片。.detach()用于分离张量,.numpy()用于将张量转换为NumPy数组
recon = outputs[k][2].detach().numpy()
for i, item in enumerate(imgs):
if i >= 9: break
plt.subplot(2, 9, i+1)
item = item.reshape(-1, 28, 28)
plt.imshow(item[0])
# 遍历原始图片数据中的前9个样本。对于每个样本,将其reshape为28x28的二维数组,并在子图中显示。plt.subplot(2, 9, i + 1)用于创建一个2x9的子图区域,其中第i+1个位置显示当前样本的图像
for i, item in enumerate(recon):
if i >= 9: break
plt.subplot(2, 9, 9+i+1)
item = item.reshape(-1, 28, 28)
plt.imshow(item[0])
# 遍历重建图片数据中的前9个样本,处理方式与原始图片相同,但是在子图中的位置从第10个开始