Tensorflow2.0笔记 - AutoEncoder做FashionMnist数据集训练

        本笔记记录自编码器做FashionMnist数据集训练,关于autoencoder的原理,请自行百度。

import os
import time
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics, Input,losses
from PIL import Image
from matplotlib import pyplot as plt
import numpy as np
from tensorflow.keras.models import Model

os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
#tf.random.set_seed(12345)
tf.__version__

#加载fashion mnist数据集
(x_train, _), (x_test, _) = datasets.fashion_mnist.load_data()
#图片像素数据范围限值到[0,1]
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.

print (x_train.shape)
print (x_test.shape)

h_dim = 64 
class Autoencoder(Model):
  def __init__(self, h_dim):
    super(Autoencoder, self).__init__()
    self.h_dim = h_dim   
    #encoder层,[b, 28, 28] => [b, 784] => [b, h_dim]
    self.encoder = tf.keras.Sequential([
      layers.Flatten(),
      layers.Dense(256, activation='relu'),
      layers.Dense(h_dim, activation='relu'),
    ])
    #decoder层,[b, h_dim] => [b,784] => [b, 28, 28]
    self.decoder = tf.keras.Sequential([
      layers.Dense(784, activation='sigmoid'),
      #恢复成28x28的图片
      layers.Reshape((28, 28))
    ])

  def call(self, x):
    encoded = self.encoder(x)
    decoded = self.decoder(encoded)
    return decoded

model = Autoencoder(h_dim)

model.compile(optimizer='adam', loss=losses.MeanSquaredError())
model.fit(x_train, x_train,
                epochs=10,
                shuffle=True,
                validation_data=(x_test, x_test))


encoded_imgs = model.encoder(x_test).numpy()
decoded_imgs = model.decoder(encoded_imgs).numpy()
n = 10
plt.figure(figsize=(20, 4))
for i in range(n):
  #绘制原始图像
  ax = plt.subplot(2, n, i + 1)
  plt.imshow(x_test[i])
  plt.title("original")
  plt.gray()
  ax.get_xaxis().set_visible(False)
  ax.get_yaxis().set_visible(False)

  #绘制重建的图像
  ax = plt.subplot(2, n, i + 1 + n)
  plt.imshow(decoded_imgs[i])
  plt.title("reconstructed")
  plt.gray()
  ax.get_xaxis().set_visible(False)
  ax.get_yaxis().set_visible(False)
plt.show()

运行结果:

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
作为一个零基础的PyCharm和PyTorch学习者,你可以按照以下步骤来复现AutoEncoder程序: 1. 安装PyTorch和相关依赖:在PyCharm的项目中,你需要先安装PyTorch和其他必要的依赖库。你可以在PyTorch官方网站上找到安装指南。 2. 导入数据集:将原始AutoEncoder-fashion_mnist数据集和VAE-fashion_mnist数据集导入到你的项目中。这些数据集通常是以numpy数组或者其他常见的数据格式提供的。你可以使用PyTorch的数据加载器(如`torchvision.datasets`)来加载数据集。 3. 构建AutoEncoder模型:根据AutoEncoder的代码,你需要定义一个新的PyTorch模型。这通常涉及到创建一个继承自`torch.nn.Module`的类,并在其中定义模型的结构和操作。你可以使用PyCharm的代码编辑器来编写这些代码。 4. 定义损失函数和优化器:根据代码,你需要定义一个适当的损失函数(如均方误差)和优化器(如Adam)。这些函数可以在PyTorch中找到并导入。 5. 训练模型:使用原始AutoEncoder-fashion_mnist数据集,你可以编写训练循环来对模型进行训练。在每个训练迭代中,你需要传递输入数据并通过模型生成输出,然后计算损失并进行反向传播优化模型。这可以使用PyTorch的张量操作和优化器功能来实现。 6. 评估模型:使用VAE-fashion_mnist数据集,你可以编写评估代码来测试训练好的模型的性能。这可能涉及到计算模型在测试数据上的重建误差或其他指标。 7. 调整超参数:根据需要,你可能需要调整模型的超参数(如学习率、隐藏层大小等)。这可以通过修改代码中的参数值来实现。 请注意,以上步骤是一般性的指导,具体实现可能因代码和数据集而异。你需要仔细阅读提供的代码和相关文档,并根据需要进行适当的调整和修改。同时,你还可以利用PyCharm提供的代码提示、调试工具和其他功能来帮助你理解和调试代码。 祝你成功复现AutoEncoder程序!如果你在实践过程中遇到任何问题,欢迎随时向我提问。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

亦枫Leonlew

希望这篇文章能帮到你

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值