小白的GAN网络学习(一)

生成对抗网络 GAN(Generative Adversarial Network)(一)

一.什么是GAN

机器学习模型大体可分为两类,生成模型(Generative Model)和判别模型(Discriminative Model),判别模型(D)需要输入变量,通过某种模型来预测,生成模型(G)是给定某种隐含信息,来随机产生观测数据

GAN主要包括了两个部分,即生成器Generator判别器Discriminator

  • 生成器(G)主要用来学习真实图像分布,从而让自身生成的图像更加真实,以骗过判别器
  • 判别器(D)则需要对接受的图片进行真假判别
  • 在训练过程中,生成器(G)努力让生成的图像更加真实,而判别器(D)则努力识别出图像的真假,这个过程相当于二人博弈,随着时间的推移,生成器(G)和判别器(D)在不断的进行对抗,同时生成器(G)的质量在不断提高
    • 对于给定的真实图片,判别器(D)要为其打上标签1
    • 对于给定的生成图片,判别器(D)要为其打上标签0
    • 对于生成器传给判别器的图片,生成器希望判别器打上标签1
    • 博弈的结果:在最理想的状态下G可以生成足以“以假乱真”的图像G(z),对于D来说,它难以判定G生成的图片究竟是不是真实的,因此D(G(z))=0.5,这样做的目的:得到了一个生成模型G,可以用来生成图片

算法实现

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-3XcNcxyl-1587697830856)(C:\Users\pxr\Desktop\一些文件夹\md图片\GAN_algorithm3.png)]

基于mnist数据集的TF代码
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import numpy
%matplotlib inline

# 加载数据集,由于只需要训练数据,所以测试数据用占位符
(images,labels),(_,_)=keras.datasets.mnist.load_data()

# 维度扩展
# 扩展后的维度是(28,28,1)
images=np.expand_dims(images,-1)

# 归一化处理,由于生成器模型使用的激活函数是tanh,所以归一化到[-1,1]的区间上
images=images/127.5 - 1

# 设定超参数
BATCH_SIZE=256
BUFFER_SIZE=60000

#创建数据管道
dataset=tf.data.Dataset.from_tensor_slices(images)
dataset=dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

# 定义生成器模型
def generate_model():
    model=keras.Sequential()
    # 这里的input_shape就是表示用一个长度为100的随机向量当做输入
    # 在GAN网络中,一般不用偏置项,所以置为False
    model.add(layers.Dense(256,input_shape=(100,),use_bias=False)
	# 添加BN层,可以加快数据的收敛              
    model.add(layers.BatchNormalization())
	# 在生成器模型中,除了最后一层,中间层的激活函数一般是relu              
	model.add(layers.ReLU())  
              
	model.add(layers.Dense(512,use_bias=False)) 
	model.add(layers.BatchNormalization())   
	model.add(layers.ReLU())    
              
	model.add(layers.Dense(28*28*1,use_bias=False)) 
	model.add(layers.BatchNormalization())
	model.add(layers.Activation('tanh')) 
	
    # 最后输出的是一张图片,所以需要reshape
	model.add(layers.Reshape((28,28,1)))  
	return model   
#-----------------------------------------------------------------
# 定义判别器模型
def discriminate_model():
	model=keras.Sequential()
	# 由于判别器接收的是一张图片,所以在最开始需要展平              
	model.add(layers.Flatten())              
	
	model.add(layers.Dense(512,use_bias=False))
	model.add(layers.BatchNormalization())
	# 在判别器模型中,常用的激活函数是leakyrelu              
	model.add(layers.LeakyReLU())    
              
	model.add(layers.Dense(256,use_bias=False))
	model.add(layers.Batch_Normalization()) 
	model.add(layers.LeakyReLU())
	
    # 这里没有添加激活函数          
	model.add(layers.Dense(1))       
	return model  
#----------------------------------------------------------------------
# 定义损失函数
# 由于我们要判断一张图片是真实的还是生成的,也就是True/False的问题,所以是一个二元分类问题
# 由于最后的输出结果没有经过激活函数,所以这里的from_logits=True              
cross_entropy=keras.losses.BinaryCrossentropy(from_logits=True)              
# 定义判别器损失
def discriminate_loss(real_out,fake_out):
	# 我们希望判别器对真实的图片的预测是1              
	real_loss=cross_entropy(tf.ones_like(real_out),real_out)  
	# 我们希望判别器对生成图片的预测是0              
	fake_loss=cross_entropy(tf.zeros_like(fake_out),fake_out)              	   	
	total_loss=real_loss+fake_loss
    return total_loss
              
# 定义生成器损失
def generate_loss(fake_out):
	# 我们希望生成器生成的图片越真实越好              
	return cross_entropy(tf.ones_like(fake_out),fake_out)  
#-----------------------------------------------------------------
# 定义优化器
generate_opt=keras.optimizers.Adam(1e-4)
discriminate_opt=keras.optimizers.Adam(1e-4)
#-----------------------------------------------------------------
EPOCHS=100
noise_dim=100
num_exp_to_generate=16
# 随机噪声种子              
seed=tf.random.normal([num_exp_to_generate,noise_dim]) 
              
# 初始化模型
generator=generate_model()
discriminator=discriminate_model()
              
# 定义训练函数
def train_step(images):
	noise=tf.random.normal([BATCH_SIZE,noise_dim])
	with tf.GradientTape() as gen_tape,tf.GradientTape() as disc_tape:
		real_out=discriminator(image,training=True)
		gen_image=generator(noise,training=True)
		fake_out=discriminator(gen_image,training=True)
		gen_loss=generate_loss(fake_out)
		disc_loss=discriminate_loss(real_out,fake_out)    
	
	gradient_gen=gen_tape.gradient(gen_loss,generate.trainable_variables)	
	gradient_disc=disc_tape.gradient(disc_loss,discriminate.trainable_variables)
	generate_opt.apply_gradient(zip(gradient_gen,generate.trainable_variables))            
	discriminate_opt.apply_gradient(zip(gradient_disc,discriminate.trainable_variables))    
#------------------------------------------------------------------
# 定义绘图函数
def generate_plot_image(gen_model,test_noise):
	pre_image=gen_model(test_noise,training_False) 
	fig=plt.figure(figsize=(4,4))
	for i in range(pre_image.shape[0]):
		plt.subplot(4,4,i+1)
		# 由于我们最开始归一化到[-1,1],所以要还原到[0,1]              
		plt.imshow((pre_image[i,:,:]+1)/2,cmap='gray')
		plt.axis('off')
	plt.show()  
#-------------------------------------------------------------------
def train(dataset,epochs):
	for epoch in range(epochs):
		for image_batch in dataset:
              train_step(image_batch)
        generate_plot_image(generator,seed)
#-----------------------------------------------------------------
train(dataset,EPOCHS)              


  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值