自编码器重建 MNIST 数据集 (下面的代码加载的是 keras.datasets.mnist; 若要重建 Fashion-MNIST, 将其替换为 keras.datasets.fashion_mnist 即可)

自编码器

import os

import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt
from PIL import Image
from tensorflow import keras
from tensorflow.keras import Sequential, layers

加载数据集

# Load MNIST; the labels are loaded but not used for training.
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
# Scale uint8 pixels from [0, 255] to float32 in [0, 1]; these images are also
# the targets of the binary cross-entropy reconstruction loss.
x_train, x_test = x_train.astype(np.float32) / 255., x_test.astype(np.float32) / 255.

# The autoencoder is unsupervised, so only the images go into the pipelines.
# A shuffle buffer at least as large as the dataset gives a uniform shuffle;
# a smaller buffer only mixes elements that are close together in the data.
train_db = tf.data.Dataset.from_tensor_slices(x_train)
train_db = train_db.shuffle(buffer_size=len(x_train)).batch(512)
test_db = tf.data.Dataset.from_tensor_slices(x_test)
test_db = test_db.batch(512)

print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)
(60000, 28, 28) (60000,)
(10000, 28, 28) (10000,)

构建网络

class AutoEncoder(keras.Model):
    """Fully connected autoencoder: input -> 256 -> 128 -> latent -> 128 -> 256 -> input.

    The decoder's last Dense layer has no activation, so ``call`` returns
    logits; apply ``tf.sigmoid`` to turn them into pixel values in (0, 1).
    """

    def __init__(self, latent_dim=20, input_dim=28 * 28, **kwargs):
        """Build the encoder and decoder stacks.

        Args:
            latent_dim: size of the bottleneck code (20 by default).
            input_dim: length of one flattened input image (784 = 28 * 28).
            **kwargs: forwarded to ``keras.Model`` (e.g. ``name``).
        """
        super().__init__(**kwargs)

        # Encoder: [b, input_dim] => [b, latent_dim]
        self.encoder = Sequential([
            layers.Dense(256, activation=tf.nn.relu),
            layers.Dense(128, activation=tf.nn.relu),
            layers.Dense(latent_dim),
        ])

        # Decoder: [b, latent_dim] => [b, input_dim] (logits, no activation)
        self.decoder = Sequential([
            layers.Dense(128, activation=tf.nn.relu),
            layers.Dense(256, activation=tf.nn.relu),
            layers.Dense(input_dim),
        ])

    def call(self, inputs, training=None):
        """Encode and then decode ``inputs``.

        Args:
            inputs: flattened images, shape [b, input_dim].
            training: Keras training flag, forwarded to both sub-models.

        Returns:
            Reconstruction logits of shape [b, input_dim].
        """
        # [b, input_dim] => [b, latent_dim]
        h = self.encoder(inputs, training=training)
        # [b, latent_dim] => [b, input_dim]
        x_hat = self.decoder(h, training=training)

        return x_hat

网络训练

def save_images(imgs, name, tile=28, grid=10):
    """Tile grayscale images into a square grid and save it as one image file.

    Tiles are placed column by column: the first ``grid`` images fill the
    leftmost column from top to bottom, the next ``grid`` the second column,
    and so on. With the defaults, the canvas is 280x280 and holds 100
    28x28 tiles.

    Args:
        imgs: sequence of 2-D arrays, each ``tile`` x ``tile`` pixels, with
            values in 0-255. Only the first ``grid * grid`` are used; if
            fewer are given, the remaining cells stay black.
        name: output file path. Its parent directory is created if missing.
        tile: side length, in pixels, of each image.
        grid: number of tiles per row and per column.
    """
    side = tile * grid
    canvas = Image.new('L', (side, side))

    # x is the outer loop, so the grid is filled column-major.
    positions = [(x, y) for x in range(0, side, tile) for y in range(0, side, tile)]
    for im, pos in zip(imgs, positions):
        # A 2-D uint8 array is converted to an 8-bit grayscale ('L') image.
        canvas.paste(Image.fromarray(np.asarray(im, dtype=np.uint8)), pos)

    parent = os.path.dirname(name)
    if parent:
        os.makedirs(parent, exist_ok=True)
    canvas.save(name)
# Instantiate the autoencoder and create its weights for flattened 28x28
# inputs, so that summary() can report the parameter counts.
model = AutoEncoder()
model.build(input_shape=(None, 28 * 28))
model.summary()

# `learning_rate` is the supported keyword; the legacy `lr` alias is
# deprecated and is rejected by Keras 3 optimizers.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
Model: "auto_encoder"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
sequential (Sequential)      multiple                  236436    
_________________________________________________________________
sequential_1 (Sequential)    multiple                  237200    
=================================================================
Total params: 473,636
Trainable params: 473,636
Non-trainable params: 0
_________________________________________________________________

开始训练

# test_db is not shuffled, so its first batch is the same every epoch: fetch it
# once and use it to visualise reconstructions after each epoch.
x_vis = next(iter(test_db))
# save_images writes into this directory.
os.makedirs('ae_images', exist_ok=True)

for epoch in range(20):
    # Mean of the per-sample loss over the whole epoch. Printing only the last
    # batch's loss is noisy, especially since the final batch is a partial one.
    epoch_loss = keras.metrics.Mean()

    for x in train_db:
        # [b, 28, 28] => [b, 784]
        x = tf.reshape(x, [-1, 28 * 28])
        # Record the forward pass for automatic differentiation.
        with tf.GradientTape() as tape:
            x_rec_logits = model(x)
            # Per-pixel binary cross-entropy against the [0, 1] inputs,
            # averaged over pixels -> one value per sample.
            per_sample_loss = tf.losses.binary_crossentropy(x, x_rec_logits, from_logits=True)
            rec_loss = tf.reduce_mean(per_sample_loss)
        grads = tape.gradient(rec_loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        epoch_loss.update_state(per_sample_loss)

    # Training loss averaged over every sample seen this epoch.
    print("epoch: ", epoch, "loss: ", float(epoch_loss.result()))

    logits = model(tf.reshape(x_vis, [-1, 784]))
    # The model outputs logits; sigmoid maps them to pixel values in (0, 1).
    x_hat = tf.sigmoid(logits)
    # [b, 784] => [b, 28, 28] to restore the image layout
    x_hat = tf.reshape(x_hat, [-1, 28, 28])

    # [100, 28, 28]: the first 50 inputs followed by their 50 reconstructions
    x_concat = tf.concat([x_vis[:50], x_hat[:50]], axis=0)
    # Scale back to 0-255 and round before the uint8 cast; truncation would
    # push values that land just below an integer down one gray level.
    x_concat = np.clip(np.rint(x_concat.numpy() * 255.), 0, 255).astype(np.uint8)
    save_images(x_concat, 'ae_images/mnist_%d.png'%epoch)

epoch:  0 loss:  0.1876431256532669
epoch:  1 loss:  0.14163847267627716
epoch:  2 loss:  0.12352141737937927
epoch:  3 loss:  0.11942803859710693
epoch:  4 loss:  0.11525192111730576
epoch:  5 loss:  0.10021436214447021
epoch:  6 loss:  0.10526927560567856
epoch:  7 loss:  0.10288294404745102
epoch:  8 loss:  0.10139968246221542
epoch:  9 loss:  0.10215207189321518
epoch:  10 loss:  0.0961870551109314
epoch:  11 loss:  0.091026671230793
epoch:  12 loss:  0.09655070304870605
epoch:  13 loss:  0.09417414665222168
epoch:  14 loss:  0.08978977054357529
epoch:  15 loss:  0.08931374549865723
epoch:  16 loss:  0.08951258659362793
epoch:  17 loss:  0.08937102556228638
epoch:  18 loss:  0.09456444531679153
epoch:  19 loss:  0.08556753396987915
def printImage(images, n=20, ncols=5, figsize=(10, 10)):
    """Plot up to ``n`` grayscale images in a grid with ``ncols`` columns.

    The grid has just enough rows for the images shown (4 rows of 5 for the
    default 20). If ``images`` has fewer than ``n`` items, all of them are
    shown.

    Args:
        images: indexable sequence of 2-D arrays (e.g. shape [N, 28, 28]).
        n: maximum number of images to plot.
        ncols: images per row.
        figsize: matplotlib figure size in inches.
    """
    count = min(n, len(images))
    # Ceiling division; keep at least one row so subplot() gets a valid grid.
    nrows = max(1, -(-count // ncols))

    plt.figure(figsize=figsize)
    for i in range(count):
        plt.subplot(nrows, ncols, i + 1)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        plt.imshow(images[i], cmap=plt.cm.binary)
# Reconstruct the first test batch with the trained model, then plot the
# first 10 inputs followed by their 10 reconstructions.
x = next(iter(test_db))
logits = model(tf.reshape(x, [-1, 784]))
# The model outputs logits; sigmoid maps them to pixel values in (0, 1).
x_hat = tf.sigmoid(logits)
# [b, 784] => [b, 28, 28] to restore the image layout
x_hat = tf.reshape(x_hat, [-1, 28, 28])

# [20, 28, 28]: the first 10 inputs followed by their 10 reconstructions
x_concat = tf.concat([x[:10], x_hat[:10]], axis=0)
# Scale back to 0-255 and round before the uint8 cast; truncation would
# push values that land just below an integer down one gray level.
x_concat = np.clip(np.rint(x_concat.numpy() * 255.), 0, 255).astype(np.uint8)
printImage(x_concat)
  • 上面 2 行是原始图片, 下面 2 行是重建后的图片 (共 10 张原图和 10 张重建图, 每行 5 张)
    在这里插入图片描述
保存到本地的图片:

第一次 epoch
左边 5 列是原图片, 右边 5 列是重建后的图片。可以看到, 此时的重建结果还不是很清晰。
在这里插入图片描述

第十次 epoch
在这里插入图片描述
第二十次 epoch
在这里插入图片描述

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值