自编码器学习笔记一

视频地址: 

https://www.bilibili.com/video/BV1Fp4y1o7Kw?t=225.9
#  自编码器学习代码,视频地址:https://www.bilibili.com/video/BV1Fp4y1o7Kw?t=225.9
#  原视频地址:https://www.youtube.com/watch?v=zp8clK9yCro
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets,transforms
import matplotlib.pyplot as plt

mnist_data = datasets.MNIST(root='./data_mnist', train=True, transform=transforms.ToTensor(),download=True)

data_loader = torch.utils.data.DataLoader(dataset=mnist_data,
                                          batch_size=64,
                                          shuffle=True)

dataiter = iter(data_loader)  #  iter()Python中的一个内置函数,该对象可以用于迭代可迭代对象(如列表、元组、字典等)。迭代器对象允许我们逐个访问可迭代对象中的元素,而不必一次性加载整个可迭代对象到内存中
images, labels = dataiter.next()  #  next() 函数逐个访问迭代器中的元素
print(torch.min(images), torch.max(images))

class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28*28, 128),  # 图片像素大小28*28
            nn.ReLU(),
            nn.Linear(128,64),
            nn.ReLU(),
            nn.Linear(64,12),
            nn.ReLU(),
            nn.Linear(12,3)
        )

        self.decoder = nn.Sequential(
            nn.Linear(3, 12),
            nn.ReLU(),
            nn.Linear(12, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 28*28),
            nn.Sigmoid()
        )

    def forward(self,x):  # 前向传播
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded

model = Autoencoder()  #  实例化model
criterion = nn.MSELoss()  #  交叉熵损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)  #  优化器

num_epochs = 10  # 定义训练多少个epoch
outputs = []    #  outputs一个空列表
for epoch in range(num_epochs):
    for(img,_) in data_loader:  #  _ 表示一个占位符,表示我们在这里不需要使用这个变量的值。在这个特定的情境中,data_loader 返回的每个批次的数据是一个元组 (img, label),其中 img 是图像数据,label 是对应的标签。然而,在这个循环中,我们似乎只关心图像数据而不关心标签,因此使用下划线 _ 表示我们暂时不需要使用这个变量的值。这样做的目的是为了提高代码的可读性,告诉阅读代码的人,我们暂时不关心标签这个变量的值
        img = img.reshape(-1,28*28)  #  展平操作,全连接层需要展平之后输入。将某一个维度设为-1,让PyTorch根据张量中元素的总数量和其他维度的大小,自动计算出该维度的大小,从而保证张量中所有元素的数量不变
        recon = model(img)
        loss = criterion(recon,img)

        optimizer.zero_grad()  # 
        loss.backward()
        optimizer.step()

    print(f'Epoch:{epoch+1}, Loss:{loss.item():.4f}')
    outputs.append((epoch, img, recon))  #  在Python中,append() 方法接受一个参数,这个参数通常是要添加到列表中的元素。如果要添加的元素是一个元组,那么这个元组本身就是一个元素,因此需要将它放在括号中以表示一个整体

for k in range(0, num_epochs, 4):  #  从0到num_epochs-1,步长为4。range用于for循环左闭右开相当于[0,epochs)
    plt.figure(figsize=(9, 2))  #  创建一个大小为9x2英寸的新图形。
    plt.gray()  #  将图形的色彩模式设置为灰度
    imgs = outputs[k][1].detach().numpy()  #  从outputs列表中获取第k个元素(对应于第k个epoch)的原始图片数据和重建图片数据。这里的outputs[k]返回一个元组,第1个元素是原始图片,第2个元素是重建图片。.detach()用于分离张量,.numpy()用于将张量转换为NumPy数组
    recon = outputs[k][2].detach().numpy()
    for i, item in enumerate(imgs):
        if i >= 9: break
        plt.subplot(2, 9, i+1)
        item = item.reshape(-1, 28, 28)
        plt.imshow(item[0])
#  遍历原始图片数据中的前9个样本。对于每个样本,将其reshape为28x28的二维数组,并在子图中显示。plt.subplot(2, 9, i + 1)用于创建一个2x9的子图区域,其中第i+1个位置显示当前样本的图像

    for i, item in enumerate(recon):
        if i >= 9: break
        plt.subplot(2, 9, 9+i+1)
        item = item.reshape(-1, 28, 28)
        plt.imshow(item[0])
#  遍历重建图片数据中的前9个样本,处理方式与原始图片相同,但是在子图中的位置从第10个开始

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值