Variational Autoencoder (VAE) Shows Inconsistent Outputs

This article shows how to build a variational autoencoder (VAE) with a multilayer perceptron (MLP) on the MNIST dataset. After training, the VAE's encoder produces latent vectors, and the decoder can generate MNIST digits by sampling latent vectors from a Gaussian distribution with mean 0 and standard deviation 1. The experiment compares three things: the original inputs, the decoder's reconstructions from the encoder mean, and the full VAE's outputs.
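The sampling step in the code below uses the reparameterization trick: instead of sampling z directly from Q(z|X), it draws eps ~ N(0, I) and computes z = mu + exp(0.5 * log_var) * eps, which keeps the sampling differentiable with respect to mu and log_var. A minimal NumPy sketch of just this trick (illustrative values, not the Keras code itself):

```python
import numpy as np

def sample_latent(z_mean, z_log_var, rng):
    """Reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, I)."""
    eps = rng.standard_normal(z_mean.shape)
    return z_mean + np.exp(0.5 * z_log_var) * eps

rng = np.random.default_rng(0)
# Many draws from a latent with mu=0, log_var=0 (i.e. sigma=1):
# the samples should themselves look like N(0, 1).
z = sample_latent(np.zeros((100_000, 2)), np.zeros((100_000, 2)), rng)
```

Because the randomness lives entirely in eps, gradients flow through mu and log_var during backpropagation, which is exactly why the Keras model below wraps this in a `Lambda` layer.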

'''Example of VAE on MNIST dataset using MLP

The VAE has a modular design. The encoder, decoder and VAE
are 3 models that share weights. After training the VAE model,
the encoder can be used to generate latent vectors.
The decoder can be used to generate MNIST digits by sampling the
latent vector from a Gaussian distribution with mean=0 and std=1.

# Reference
[1] Kingma, Diederik P., and Max Welling.
"Auto-encoding variational bayes."
https://arxiv.org/abs/1312.6114
'''

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from keras.layers import Lambda, Input, Dense
from keras.models import Model
from keras.datasets import mnist
from keras.losses import mse, binary_crossentropy
from keras.utils import plot_model
from keras import backend as K

import numpy as np
import matplotlib.pyplot as plt
import os


# reparameterization trick
# instead of sampling from Q(z|X), sample eps = N(0,I)
# z = z_mean + sqrt(var)*eps
def sampling(args):
    """Reparameterization trick by sampling from an isotropic unit Gaussian.

    # Arguments:
        args (tensor): mean and log of variance of Q(z|X)

    # Returns:
        z (tensor): sampled latent vector
    """
    z_mean, z_log_var = args
    batch = K.shape(z_mean)[0]
    dim = K.int_shape(z_mean)[1]
    # by default, random_normal has mean=0 and std=1.0
    epsilon = K.random_normal(shape=(batch, dim))
    return z_mean + K.exp(0.5 * z_log_var) * epsilon


# MNIST dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()

image_size = x_train.shape[1]
original_dim = image_size * image_size
x_train = np.reshape(x_train, [-1, original_dim])
x_test = np.reshape(x_test, [-1, original_dim])
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

# network parameters
input_shape = (original_dim,)
intermediate_dim = 512
batch_size = 128
latent_dim = 32
epochs = 50

# VAE model = encoder + decoder
# build encoder model
inputs = Input(shape=input_shape, name='encoder_input')
x = Dense(intermediate_dim, activation='relu')(inputs)
z_mean = Dense(latent_dim, name='z_mean')(x)
z_log_var = Dense(latent_dim, name='z_log_var')(x)

# use reparameterization trick to push the sampling out as input
# note that "output_shape" isn't necessary with the TensorFlow backend
z = Lambda(sampling, output_shape=(latent_dim,), name='z')([z_mean, z_log_var])

# instantiate encoder model
encoder = Model(inputs, [z_mean, z_log_var, z], name='encoder')
encoder.summary()
# plot_model(encoder, to_file='vae_mlp_encoder.png', show_shapes=True)

# build decoder model
latent_inputs = Input(shape=(latent_dim,), name='z_sampling')
x = Dense(intermediate_dim, activation='relu')(latent_inputs)
outputs = Dense(original_dim, activation='sigmoid')(x)

# instantiate decoder model
decoder = Model(latent_inputs, outputs, name='decoder')
decoder.summary()
# plot_model(decoder, to_file='vae_mlp_decoder.png', show_shapes=True)

# instantiate VAE model
outputs = decoder(encoder(inputs)[2])
vae = Model(inputs, outputs, name='vae_mlp')

if __name__ == '__main__':
    models = (encoder, decoder)
    data = (x_test, y_test)

    # VAE loss = mse_loss or xent_loss + kl_loss
    # reconstruction_loss = mse(inputs, outputs)
    reconstruction_loss = binary_crossentropy(inputs, outputs)
    reconstruction_loss *= original_dim
    kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
    kl_loss = K.sum(kl_loss, axis=-1)
    kl_loss *= -0.5
    vae_loss = K.mean(reconstruction_loss + kl_loss)
    vae.add_loss(vae_loss)
    vae.compile(optimizer='adam')
    vae.summary()

    # train the autoencoder
    vae.fit(x_train,
            epochs=epochs,
            batch_size=batch_size,
            validation_data=(x_test, None))
    # vae.save_weights('vae_mlp_mnist.h5')

    z_mean, z_log_var, z = encoder.predict(x_test)
    decoded_imgs = decoder.predict(z_mean)
    Y_img = vae.predict(x_test)

    n = 10  # how many digits we will display
    plt.figure(figsize=(20, 4))
    for i in range(n):
        # display original
        ax = plt.subplot(3, n, i + 1)
        plt.imshow(x_test[i].reshape(28, 28))
        plt.gray()
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

        # display reconstruction from the encoder mean
        ax = plt.subplot(3, n, i + 1 + n)
        plt.imshow(decoded_imgs[i].reshape(28, 28))
        plt.gray()
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

        # display reconstruction from the full VAE (sampled z)
        ax = plt.subplot(3, n, i + 1 + 2 * n)
        plt.imshow(Y_img[i].reshape(28, 28))
        plt.gray()
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    plt.show()
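The "inconsistent output" in the title comes from the `Lambda` sampling layer: `vae.predict(x_test)` decodes a freshly sampled z, which is stochastic on every call, while `decoder.predict(z_mean)` decodes the deterministic encoder mean. A small NumPy sketch of the same effect (hypothetical mean and log-variance values, not taken from the trained model):

```python
import numpy as np

def reparameterize(z_mean, z_log_var, rng):
    """Same formula as the sampling() Lambda layer: z = mu + sigma * eps."""
    eps = rng.standard_normal(z_mean.shape)
    return z_mean + np.exp(0.5 * z_log_var) * eps

rng = np.random.default_rng(42)
z_mean = np.array([[0.3, -1.2]])      # hypothetical encoder mean
z_log_var = np.array([[-2.0, -2.0]])  # hypothetical encoder log-variance

z1 = reparameterize(z_mean, z_log_var, rng)  # what vae.predict decodes (call 1)
z2 = reparameterize(z_mean, z_log_var, rng)  # what vae.predict decodes (call 2)
# z1 and z2 differ, so the decoded images differ slightly between calls;
# decoding z_mean itself is deterministic, which is why the second and
# third plot rows do not match exactly.
```

If reproducible reconstructions are wanted, decode `z_mean` (as the second plot row already does) rather than the sampled `z`; the sampled path is what training and generation need, not evaluation.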
