小白的GAN网络学习(四)- ACGAN

ACGAN

ACGAN & CGAN
  • CGAN通过在生成器和判别器中均使用标签信息进行训练,仅能产生特定标签的数据
  • ACGAN是CGAN的另一种实现,既使用标签信息进行训练,同时也重建标签信息
  • 生成器的输入包括class 和 noise 两部分,其中class为训练数据的标签(batch_size,channel,height,width)
  • 判别器的输入为图片(生成图片和真实图片),输出为两部分:
    • 1.源数据真假的判断,形状为(batch_size,1)
    • 2.输入数据的分类结果,形状为(batch_size,class_num)
    • 所以判别器的最后一层有连个并列的全连接层,分别到这两部分的输出结果,即判别器的输出有两个张量(真假判断张量和分类结果张量)
ACGAN的损失函数
  • 对判别器而言,既希望分类正确,又希望能够正确分别数据的真/假
    • 判别器的损失函数: L D = L S + L C L_D=L_S+L_C LD=LS+LC
      • 判断真假损失: L S = E [ log ⁡ P ( S = r e a l ∣ x r e a l ) ] + E [ log ⁡ P ( S = f a k e ∣ x f a k e ) ] L_S=E[\log{P(S=real|x_{real})}]+E[\log{P(S=fake|x_{fake})}] LS=E[logP(S=realxreal)]+E[logP(S=fakexfake)]
      • 分类损失: L C = E [ log ⁡ P ( C = c ∣ x r e a l ) ] + E [ log ⁡ P ( C = c ∣ x f a k e ) ] L_C=E[\log{P(C=c|x_{real})}]+E[\log{P(C=c|x_{fake})}] LC=E[logP(C=cxreal)]+E[logP(C=cxfake)]
  • 对生成器而言,希望能够分类正确,但是希望判别器不能分辨数据的真假(由此形成对抗结构)
    • 生成器的损失函数: L D = L C + L S L_D=L_C+L_S LD=LC+LS
      • 判断真假损失: L S = E [ log ⁡ P ( S = f a k e ∣ x f a k e ) ] L_S=E[\log{P(S=fake|x_{fake})}] LS=E[logP(S=fakexfake)]
      • 分类损失: L C = E [ log ⁡ P ( C = c ∣ x f a k e ) ] L_C=E[\log{P(C=c|x_{fake})}] LC=E[logP(C=cxfake)]
基于mnist数据集的ACGAN实现
import tensorflow as tf
from tensorflow import keras 
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline

(images,labels),(_,_)=keras.datasets.mnist.load_data()
images=np.expand_dims(images,-1)
images=images/127.5 - 1

dataset=tf.data.Dataset.from_tensor_slices((images,labels))

BATCH_SIZE=128
noise_dim=50
BUFFER_SIZE=images.shape[0]

dataset=dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
#-----------------------------------------------------------------------
# 定义生成器
def generate_model():
    seed=layers.Input(shape=(noise_dim,))
    label=layers.Input(())
    
    x=layers.Embedding(10,50,input_length=1)(label)
    x=layers.concatenate([seed,x])
    x=layers.Dense(3*3*128,use_bias=False)(x)
    x=layers.Reshape((3,3,128))(x)          
    x=layers.BatchNormalization()(x)
    x=layers.ReLU()(x)
    
    x=layers.Conv2DTranspose(64,(3,3),strides=2,use_bias=False)(x)
    x=layers.BatchNormalization()(x)
    x=layers.ReLU()(x)
    
    x=layers.Conv2DTranspose(32,(3,3),strides=2,use_bias=False,padding='same')(x)
    x=layers.BatchNormalization()(x)
    x=layers.ReLU()(x)
    
    x=layers.Conv2DTranspose(1,(3,3),strides=2,use_bias=False,padding='same')(x)
    x=layers.Activation('tanh')(x)
    
    model=keras.models.Model(inputs=(seed,label),outputs=x)
    return model
#-----------------------------------------------------------------------
# 定义判别器
def discriminate_model():
    image=layers.Input(shape=(28,28,1))
    
    x=layers.Conv2D(32,(3,3),strides=2,padding='same',use_bias=False)(image)
    x=layes.BatchNormalization()(x)
    x=layers.LeakyReLU()(x)
    x=layers.Dropout(0.5)(x)
    
    x=layers.Conv2D(64,(3,3),strides=2,padding='same',use_bias=False)(x)
    x=layers.BatchNormalization()(x)
    x=layers.LeakyReLU()(x)
    x=layers.Dropout(0.5)(x)
    
    x=layers.Conv2D(128,(3,3),strides=2,padding='same',use_bias=False)(x)
    x=layers.BatchNormalization()(x)
    x=layers.LeakyReLU()(x)
    x=layers.Dropout(0.5)
    
    x=layers.Flatten()(x)
    
    # 真假输出
    out=layes.Dense(1)(x)
    # 分类输出
    classifacation_out=layers.Dense(10)(x)
    model=keras.models.Model(inputs=image,outputs=(out,classifacation_out))
    return model
#-------------------------------------------------------------------------
# 定义损失函数
gen=generate_model()
disc=discriminate_model()

# 由于一方面要判断真假输出,另一方面要判断分类输出,所以损失函数也应该有两个
bce=keras.losses.BinaryCrossentropy(from_logits=True)
cce=keras.losses.SparseCategorialCrossentropy(from_logits=True)

def disc_loss(real_out,real_class_out,fake_out,label):
    real_loss=bce(tf.ones_like(real_out),real_out)
    fake_loss=bce(tf.zeros_like(fake_out),fake_out)
    cat_loss=cce(label,real_class_out)
    total_loss=real_loss+fake_loss+cat_loss
    return total_loss

def gen_loss(fake_out,fake_class_out,label):
    fake_loss=bce(tf.ones_like(fake_out),fake_out)
    cat_loss=cce(label,fake_class_out)
    total_loss=fake_loss+cat_loss
    return total_loss

gen_opt=keras.optimizers.Adam(1e-5)
disc_opt=keras.optimizers.Adam(1e-5)
#------------------------------------------------------------------------
# 自定义训练
def train_step(image,label):
    size=label.shape[0]
    noise=tf.random.normal((size,noise_dim))
    
     with tf.GradientTape() as gen_tape,tf.GradientTape() as disc_tape:
        gen_imgs=gen((noise,label),training=True)
        fake_out,fake_class_out=disc(gen_imgs,training=True)
        real_out,real_class_out=disc(image,training=True)
        
        d_loss=disc_loss(real_out,real_class_out,fake_out,label)
        g_loss=gen_loss(fake_out,fake_class_out,label)
        
    gen_grad=gen_tape.gradient(g_loss,gen.trainable_variables)
    disc_grad=disc_tape.gradient(d_loss,disc.trainable_variables)
    
    gen_opt.apply_gradients(zip(gen_grad,gen.trainable_variables))
    disc_opt.apply_gradients(zip(disc_grad,disc.trainable_variables))
#-----------------------------------------------------------------------
# 自定义绘图函数
def plot_gen_image(model,noise,label,epoch_num):
    print('Epoch:',epoch_num)
    gen_image=model((noise,label),training=False)
    # 压缩维度,将28*28*1 转换成28*28的图像
    gen_image=tf.squeeze(gen_image)  
    fig=plt.figure(figsize=(10,1))
    for i in range(gen_image.shape[0]):
        plt.subplot(1,10,i+1)
        plt.imshow((gen_image[i,:,:]+1)/2,cmap='gray')
        plt.axes('off')
    plt.show()
  • 1
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值