3.2自编码器(变分自编码器,VAE)

拓展(Keras + fashion_mnist)

承接上一篇博客:3.自编码器(变分自编码器,VAE)

# 加载库
import numpy as np
import matplotlib.pyplot as plt

from keras.layers import Input, Dense, Lambda
from keras.models import Model, Sequential
from keras import backend as K
from keras import objectives
from keras.datasets import fashion_mnist


# 加载数据并训练,CPU训练的速度还算能忍受
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()

# 转浮点 + 归一化
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.

# reshape数据形状,适用于Dense层的input
x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))
x_test = x_test.reshape((len(x_test), np.prod(x_train.shape[1:])))

print(x_train.shape, y_train.shape, x_test.shape, y_test.shape) # (60000, 784) (60000,) (10000, 784) (10000,)

运行结果:
在这里插入图片描述

# 显示出前16张图片
plt.figure(num='fig1', figsize=(10,10), dpi=75, facecolor='#666666', edgecolor='#FF0000')
for i in range(0,16):
    ax = plt.subplot(4, 4, i+1, label='one')
    plt.imshow(x_test[i].reshape(28, 28))
    plt.xticks([]) # 去除x坐标
    plt.yticks([])
    plt.title('label:{0}'.format(y_test[i])) # 子框的标题
    plt.ylabel('') # y轴注释
    plt.xlabel('') # x轴注释

plt.show()
plt.close() # 释放内存

运行结果:
在这里插入图片描述

# 定义模型参数
encoding_dim = 2 # 一个是均值一个是方差

# encoder部分
input_img = Input(shape=(784,)) # 28*28
encoded1 = Dense(units=128, activation='relu')(input_img)
encoded2 = Dense(units=32, activation='relu')(encoded1)

z_mean = Dense(encoding_dim)(encoded2)
z_log_var = Dense(encoding_dim)(encoded2)

# Lambda层不参与训练,只参与计算,用于后面产生新的z
def sampling(args):
    z_mean, z_log_var = args
    epsilon = K.random_normal(shape=(100, encoding_dim), mean=0.) # 生成batch_size个2维的均值为0方差为1的正态分布数据
    return z_mean + K.exp(z_log_var / 2) * epsilon # 得到指定均值和方差的数据

z = Lambda(sampling, output_shape=(encoding_dim,), name='Z_layer')([z_mean, z_log_var]) # 生成新的z数据


# decoder部分,与Encoder层对应
decoded1_layer = Dense(32, activation='relu')
decoded2_layer = Dense(128, activation='relu')
decoded_out_layer = Dense(784, activation='sigmoid')# 包含784=28*28个神经元

decoded1 = decoded1_layer(z) # 2->32
decoded2 = decoded2_layer(decoded1) # 32->128
decoded_out = decoded_out_layer(decoded2) # 128->784


# 自定义总的损失函数并编译模型
def vae_loss(input_img, decoded_out):
    xent_loss = 784 * objectives.binary_crossentropy(input_img, decoded_out) # 计算交叉熵
    kl_loss = -0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1) # KL距离
    return xent_loss + kl_loss

vae = Model(input_img, decoded_out) # Variational AutoEncoder变分自编码器
vae.compile(optimizer='rmsprop', loss=vae_loss)

# 提取VAE模型中的编码器 -> 可视化中间结果
encoder = Model(input=vae.input, output=vae.get_layer('Z_layer').output) 

# 构建解码器 -> 生成新图片
decoded_input = Input(shape=(2,)) # 2
decoded1 = decoded1_layer(decoded_input) # 2->32
decoded2 = decoded2_layer(decoded1) # 32->128
decoded_out = decoded_out_layer(decoded2) # 128->784
decoder = Model(input=decoded_input, output=decoded_out) 

vae.summary()
encoder.summary()
decoder.summary()

运行结果:
在这里插入图片描述在这里插入图片描述在这里插入图片描述

# 训练模型
vae.fit(x_train, x_train, # 将原图与解码后的图作对比计算损失值
    shuffle = True,
    epochs = 50, # 每个训练epoch后,数据会打乱
    batch_size = 100,
    validation_data = (x_test, x_test))

# 用编码器查看中间结果
x_test_encoded = encoder.predict(x_test, batch_size=100)
plt.figure(figsize=(6, 6))
plt.scatter(x_test_encoded[:,0], x_test_encoded[:,1], c=y_test)
plt.colorbar()
plt.show()

运行结果:
在这里插入图片描述在这里插入图片描述

# 查看整个模型的解码效果
decoded_imgs = vae.predict(x_test, batch_size=100)  # 测试集合输入查看器去噪之后输出。

# 测试集中选取10张可视化
n = 10
plt.figure(figsize=(20, 4))
for i in range(n):
    # 原图
    ax = plt.subplot(2, n, i + 1)
    plt.imshow(x_test[i].reshape(28, 28))
    plt.title('label:{0}'.format(y_test[i]))
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

    # 解码后的图
    ax = plt.subplot(2, n, i + 1 + n)
    plt.imshow(decoded_imgs[i].reshape(28, 28))
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

plt.show()

运行结果:

在这里插入图片描述

# 使用解码器来生成图片

# 用网格的方法产生一些二维数据,作为新的z输入到生成器,并将生成的x显示出来
n = 20
digit_size = 28
grid_x = np.linspace(-3, 3, n)
grid_y = np.linspace(-3, 3, n)
figure = np.zeros((digit_size * n , digit_size * n))

for i, xi in enumerate(grid_x):
    for j, yi in enumerate(grid_y):
        z_sample = np.array([[yi, xi]])
        x_decoded = decoder.predict(z_sample)
        digit = x_decoded[0].reshape(digit_size, digit_size)
        figure[(n-i-1)*digit_size : (n-i)*digit_size, j*digit_size : (j+1)*digit_size] = digit

plt.figure(figsize=(10, 10)) 
plt.imshow(figure)
plt.show()

运行结果:
在这里插入图片描述

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值