小白的GAN网络学习(三)- CGAN

条件GAN - Condition GAN (CGAN)

原始GAN的缺点
  • 生成的图像是随机的,不可预测,无法控制网络输出特定的图片,生成目标不明确,可控性不强

  • 针对原始GAN不能生成具有特定属性图片的问题,CGAN的核心在于将属性信息y融入生成器G和判别器D中,属性y可以是任何标签类信息

    • 原来G接收的是noise,现在还要附加一个条件y
    • 原来D接收的是G生成的图片和原始图片,现在也要附加一个条件y
    • 将无监督学习转化为有监督学习
    • 原始GAN:
      • m i n G m a x D V ( D , G ) = E x − p d a t a ( x ) [ log ⁡ D ( x ) ] + E x − p z ( x ) [ log ⁡ 1 − D ( G ( z ) ) ] min_Gmax_DV(D,G)=E_{x-p_{data}(x)}[\log{D(x)}]+E_{x-p_{z}(x)}[\log{1-D(G(z))}] minGmaxDV(D,G)=Expdata(x)[logD(x)]+Expz(x)[log1D(G(z))]
    • CGAN:
      • m i n G m a x D V ( D , G ) = E x − p d a t a ( x ) [ log ⁡ D ( x ∣ y ) ] + E x − p z ( x ) [ log ⁡ 1 − D ( G ( z ∣ y ) ) ] min_Gmax_DV(D,G)=E_{x-p_{data}(x)}[\log{D(x|y)}]+E_{x-p_{z}(x)}[\log{1-D(G(z|y))}] minGmaxDV(D,G)=Expdata(x)[logD(xy)]+Expz(x)[log1D(G(zy))]
      • 其中x表示原始图像,z表示noise
  • CGAN的缺陷:

    生成的图像边缘模糊,分辨率不够

基于mnist的CGAN代码实现
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline

# 处理数据
(images,labels),(_,_)=keras.datasets.mnist.load_data()
images=images/127.5 - 1
images=np.expand_dims(images,-1)
dataset=tf.data.Dataset.from_tensor_slices((images,labels))

BATCH_SIZE=128
noise_dim=50
BUFFER_SIZE=images.shape[0]

dataset=dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
#--------------------------------------------------------------------
# 定义生成器模型
def generate_model():
    seed=layers.Input(shape=(noise_dim,))
    label=layers.Input(shape(()))
    
    # 添加一个Embedding层,目的是将label和noise合起来
    x=layers.Embedding(10,50,input_length=1)(label)
    x=layers.concatenate([seed,x])(x)
    x=layers.Dense(3*3*128,use_bias=False)(x)
    x=layers.Reshape((3,3,128))(x)
    x=layers.BatchNormalization()(x)
    x=layers.ReLU()(x)
    
	# 注意:这里没有指定padding的参数,则默认为valid,也就是说,经过这一层后图像会变成(7,7,64)
    x=layers.Conv2DTranspose(64,(3,3),strides=(2,2),use_biase=False)(x)
    x=layers.BatchNormalization()(x)
    x=layers.ReLU()(x)
    
    # 经过该层后,图像变成(14,14,32)
    x=layes.Conv2DTranspose(32,(3,3),strides=(2,2),padding='same',use_bias=False)
    x=layers.BatchNormalization()(x)
    x=layers.ReLU()(x)
    
    # 经过该层后,图像变成(28,28,1)
    x=layers.Conv2DTranspose(1,(3,3),strides=(2,2),use_bias=False,padding='same')(x)
    x=layers.Activation('tanh')(x)
    
    model=keras.models.Model(inputs=(seed,label),outputs=x)
    
    return model
#------------------------------------------------------------------------
# 定义判别器模型
def discriminate_model():
    image=layers.Input(shape=(28,28,1))
    label=layers.Input(shape=(()))
    
    # 同样的,一个Embedding层将image和label合起来
    x=layers.Embedding(10,28*28,input_length=1)(label)
    x=layers.Reshape((28,28,1))(x)
    x=layers.concatenate([x,image])
    
    x=layers.Conv2D(32,(3,3),strides=(2,2),padding='same',use_bias=False)(x)
    x=layers.BatchNormalization()(x)
    x=layers.LeakyReLU()(x)
    x=layers.Dropout(0.3)(x)
    
    x=layers.Conv2D(64,(3,3),strides=(2,2),padding='same',use_bias=False)(x)
    x=layers.BatchNormalization()(x)
    x=layers.LeakyReLU()(x)
    x=layers.Dropout(0.3)(x)

    x=layers.Conv2D(128,(3,3),strides=(2,2),padding='same',use_bias=False)(x)
    x=layers.BatchNormalization()(x)
    x=layers.LeakyReLU()(x)
    x=layers.Dropout(0.3)(x)

    x=layers.Flatten()(x)
    
    out=layers.Dense(1)(x)
    
    model=keras.models.Model(inputs=(image,label),outputs=out)
    return model
#-------------------------------------------------------------------
# 实例化对象并自定义损失函数
generator=generate_model()
discriminator=discriminate_model()

bce=keras.losses.BinaryCrossentropy(from_logits=True)

# 这里的损失函数和GAN/DCGAN网络的损失函数并无差别
def disc_loss(real_out,fake_out):
    real_loss=bce(tf.ones_like(real_out),real_out)
    fake_loss=bce(tf.zeros_like(fake_out),fake_out)
    total_loss=real_loss+fake_loss
    return total_loss

def gen_loss(fake_out):
    return bce(tf.ones_like(fake_out),fake_out)
#---------------------------------------------------------------------
# 定义优化器&自定义训练
gen_opt=keras.optimizers.Adam(1e-5)
disc_opt=keras.optimizers.Adam(1e-5)

@tf.function
def train_step(image,label):
    size=label.shape[0]
    noise=tf.random.normal((size,noise_dim))
    
    with tf.GradientTape() as gen_tape,tf.GradientTape() as disc_tape:
        gen_image=generator((noise,label),training=True)
        real_out=discriminator((image,label),training=True)
        fake_out=discriminator((gen_image,label),training=True)
        
        discriminate_loss=disc_loss(real_out,fake_out)
        generate_loss=gen_loss(fake_out)
    
	gen_grad=gen_tape.gradient(generate_loss,generator.trainable_variables)
    disc_grad=disc_tape.gradient(discriminate_loss,discriminator.trainable_variables)
    gen_opt.apply_gradient(zip(gen_grad,generator.trainable_variables))
    disc_opt.apple_gradient(zip(disc_grad,discriminator.trainable_varialbes))
#---------------------------------------------------------------------
# 自定义绘图函数
def plot_gen_image(model,noise,label,epoch_num):
    print('Epoch:',epoch_num)
    gen_image=model((noise,label),training=False)
    # 压缩维度,将28*28*1 转换成28*28的图像
    gen_image=tf.squeeze(gen_image)  
    fig=plt.figure(figsize=(10,1))
    for i in range(gen_image.shape[0]):
        plt.subplot(1,10,i+1)
        plt.imshow((gen_image[i,:,:]+1)/2)
        plt.axes('off')
    plt.show()
#-----------------------------------------------------------------------
# 启动函数
noise_seed=tf.random.normal([10,noise_dim])
label_seed=np.random.randint(0,10,size=(10))
def train(dataset,epochs):
    for epoch in range(epochs):
        for image_batch,label_batch in dataset:
            train_step(image_batch,label_batch)
        if epoch % 10 ==0:
            plot_gen_image(generator,noise_seed,label_seed,epoch)
    plot_gen_image(generator,noise_seed,label_seed,epoch)
    
EPOCHS=200
train(dataset,EPOCHS)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值