ACGAN(Auxiliary Classifier GAN)详解与实现(tensorflow2.x实现)

本文详细介绍了ACGAN(辅助分类生成对抗网络)的工作原理,与CGAN的区别在于,ACGAN的判别器不仅判断图像真假,还能预测图像类别。通过强制网络执行图像分类作为辅助任务来提升生成图像的质量。文章提供了Tensorflow2.x的ACGAN实现,包括生成器和鉴别器的构建,模型训练过程,并展示了生成的虚假图像。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

ACGAN原理

ACGAN的原理GAN(CGAN)相似。对于CGAN和ACGAN,生成器输入均为潜在矢量及其标签,输出是属于输入类标签的伪造图像。对于CGAN,判别器的输入是图像(包含假的或真实的图像)及其标签, 输出是图像属于真实图像的概率。对于ACGAN,判别器的输入是一幅图像,而输出是该图像属于真实图像的概率以及其类别概率。
ACGAN架构本质上,在CGAN中,向网络提供了标签。在ACGAN中,使用辅助解码器网络重建辅助信息。ACGAN理论认为,强制网络执行其他任务可以提高原始任务的性能。在这种情况下,辅助任务是图像分类。原始任务是生成伪造图像。
判别器目标函数:
L ( D ) = − E x ∼ p d a t a l o g D ( x ) − E z l o g [ 1 − D ( G ( z ∣ y ) ) ] − E x ∼ p d a t a p ( c ∣ x ) − E z l o g p ( c ∣ g ( z ∣ y ) ) \mathcal L^{(D)} = -\mathbb E_{x\sim p_{data}}logD(x)-\mathbb E_zlog[1 − D(G(z|y))]-\mathbb E_{x\sim p_{data}}p(c|x)-\mathbb E_zlogp(c|g(z|y)) L(D)=ExpdatalogD(x)Ezlog[1D(G(zy))]Expdatap(cx)Ezlogp(cg(zy))
生成器目标函数:
L ( G ) = − E z l o g D ( g ( z ∣ y ) ) − E z l o g p ( c ∣ g ( z ∣ y ) ) \mathcal L^{(G)} = -\mathbb E_{z}logD(g(z|y))-\mathbb E_zlogp(c|g(z|y)) L(G)=EzlogD(g(zy))Ezlogp(cg(zy))

ACGAN实现

模块导入

import tensorflow as tf
from tensorflow import keras
import numpy as np
from matplotlib import pyplot as plt
import os
import math
from PIL import Image

生成器

def generator(inputs,image_size,activation='sigmoid',labels=None):
    """生成网络
    Arguments:
        inputs (layer): 输入
        image_size (int): 图片尺寸
        activation (string): 输出层激活函数
        labels (tensor): 标签
    returns:
        model: 生成网络
    """
    image_resize = image_size // 4
    kernel_size = 5
    layer_filters = [128,64,32,1]
    inputs = [inputs,labels]
    x = keras.layers.concatenate(inputs,axis=1)
    
    x = keras.layers.Dense(image_resize*image_resize*layer_filters[0])(x)
    x = keras.layers.Reshape((image_resize,image_resize,layer_filters[0]))(x)
    for filters in layer_filters:
        if filters > layer_filters[-2]:
            strides = 2
        else:
            strides = 1
        x = keras.layers.BatchNormalization()(x)
        x = keras.layers.Activation('relu')(x)
        x = keras.layers.Conv2DTranspose(filters=filters,
                kernel_size=kernel_size,
                strides=strides,
                padding='same')(x)
    if activation is not None:
        x = keras.layers.Activation(activation)(x)
    return keras.Model(inputs,x,name='generator')

鉴别器

def discriminator(inputs,activation='sigmoid',num_labels=None):
    """生成网络
    Arguments:
        inputs (Layer): 输入
        activation (string): 输出层激活函数
        num_labels (int): 类别数
    Returns:
        Model: 鉴别网络
    """
    kernel_size = 5
    layer_filters = [32,64,128,256]
    x = inputs
    for filters in layer_filters:
        if filters == layer_filters[-1]:
            strides = 1
        else:
            strides = 2
        x = keras.layers.LeakyReLU(0.2)(x)
        x = keras.layers.Conv2D(filters=filters,
                kernel_size=kernel_size,
                strides=strides,
                padding='same')(x)
    x = keras.layers.Flatten()(x)
    outputs = keras.layers.Dense(1)(x)
    if activation is not None:
        print(activation)
        outputs = keras.layers.Activation(activation)(outputs)
    if num_labels:
        #ACGAN有第二个输出,用于输出图片的类别
        layer = keras.layers.Dense(layer_filters[-2])(x)
        labels = keras.layers.Dense(num_labels)(layer)
        labels = keras.layers.Activation('softmax',name='label')(labels)
        outputs = [outputs,labels]
    return keras.Model(inputs,outputs,name='discriminator')

模型构建

def build_and_train_models():
    """The ACGAN training
    """
    #数据加载及预处理
    (x_train,y_train),_ = keras.datasets.mnist.load_data()
    image_size = x_train.shape[1]
    x_train = np.reshape(x_train,[-1,image_size,image_size,1])
    x_train = x_train.astype('float32') / 255.
    num_labels = len(np.unique(y_train))
    y_train = keras.utils.to_categorical(y_train)

    #超参数
    model_name = 'acgan-mnist'
    latent_size = 100
    batch_size = 64
    train_steps = 40000
    lr = 2e-4
    decay = 6e-8
    input_shape = (image_size,image_size,1)
    label_shape = (num_labels,)
    
    #discriminator
    inputs = keras.layers.Input(shape=input_shape,name='discriminator_input')
    discriminator = discriminator(inputs,num_labels=num_labels)
    optimizer = keras.optimizers.RMSprop(lr=lr,decay=decay)
    loss = ['binary_crossentropy','categorical_crossentropy']
    discriminator.compile(loss=loss,optimizer=optimizer,metrics=['acc'])
    discriminator.summary()

    #generator
    input_shape = (latent_size,)
    inputs = keras.layers.Input(shape=input_shape,name='z_input')
    labels = keras.layers.Input(shape=label_shape,name='labels')
    generator = generator(inputs,image_size,labels=labels)
    generator.summary()
    optimizer = keras.optimizers.RMSprop(lr=lr*0.5,decay=decay*0.5)
    discriminator.trainable = False
    adversarial = keras.Model([inputs,labels],discriminator(generator([inputs,labels])),
            name=model_name)
    adversarial.compile(loss=loss,optimizer=optimizer,metrics=['acc'])
    adversarial.summary()

    models = (generator,discriminator,adversarial)
    data = (x_train,y_train)
    params = (batch_size,latent_size,train_steps,num_labels,model_name)
    train(models,data,params)

模型训练

def train(models,data,params):
    """Train the discriminator and adversarial Networks
    Arguments:
        models (list): generator,discriminator,adversarial
        data (list): x_train,y_train
        params (list): network parameter
    """
    generator,discriminator,adversarial = models
    x_train,y_train = data
    batch_size,latent_size,train_steps,num_labels,model_name = params
    save_interval = 500
    noise_input = np.random.uniform(-1.,1.,size=[16,latent_size])
    noise_label = np.eye(num_labels)[np.arange(0,16) % num_labels]
    train_size = x_train.shape[0]
    print(model_name,'Labels for generated images: ',np.argmax(noise_label,axis=1))
    for i in range(train_steps):
        #训练鉴别器
        rand_indexes = np.random.randint(0,train_size,size=batch_size)
        real_images = x_train[rand_indexes]
        real_labels = y_train[rand_indexes]
        #产生伪造图片
        noise = np.random.uniform(-1.,1.,size=(batch_size,latent_size))
        fake_labels = np.eye(num_labels)[np.random.choice(num_labels,batch_size)]
        fake_images = generator.predict([noise,fake_labels])
        #构造输入
        x = np.concatenate((real_images,fake_images))
        #训练类别标签
        labels = np.concatenate((real_labels,fake_labels))
        #标签
        y = np.ones([2*batch_size,1])
        y[batch_size:,:] = 0.0
        #训练模型
        metrics = discriminator.train_on_batch(x,[y,labels])
        fmt = '%d: [disc loss: %f, srcloss: %f],'
        fmt += 'lbloss: %f, srcacc: %f, lblacc: %f'
        log = fmt % (i,metrics[0],metrics[1],metrics[2],metrics[3],metrics[4])

        #train adversarial network for 1 batch
        noise = np.random.uniform(-1.,1.,size=(batch_size,latent_size))
        fake_labels = np.eye(num_labels)[np.random.choice(num_labels,batch_size)]
        y = np.ones([batch_size,1])
        metrics = adversarial.train_on_batch([noise,fake_labels],[y,fake_labels])
        fmt = "%s [advr loss: %f, srcloss: %f,"
        fmt += "lblloss: %f, srcacc: %f, lblacc: %f]"
        log = fmt % (log, metrics[0], metrics[1], metrics[2], metrics[3], metrics[4])
        print(log)
        if (i + 1) % save_interval == 0:
            # 绘制生成图片
            plot_images(generator,noise_input=noise_input,
                    noise_label=noise_label,show=False,
                    step=(i + 1),
                    model_name=model_name)
    generator.save(model_name + ".h5")

虚假图像生成及绘制plot_images函数

def plot_images(generator,
                noise_input,
                noise_label=None,
                noise_codes=None,
                show=False,
                step=0,
                model_name="gan"):
    """生成虚假图片及绘制

    # Arguments
        generator (Model): 生成模型
        noise_input (ndarray): 潜在模型
        show (bool): 是否展示
        step (int): step值
        model_name (string): 模型名称

    """
    os.makedirs(model_name, exist_ok=True)
    filename = os.path.join(model_name, "%05d.png" % step)
    rows = int(math.sqrt(noise_input.shape[0]))
    if noise_label is not None:
        noise_input = [noise_input, noise_label]
        if noise_codes is not None:
            noise_input += noise_codes

    images = generator.predict(noise_input)
    plt.figure(figsize=(2.2, 2.2))
    num_images = images.shape[0]
    image_size = images.shape[1]
    for i in range(num_images):
        plt.subplot(rows, rows, i + 1)
        image = np.reshape(images[i], [image_size, image_size])
        plt.imshow(image, cmap='gray')
        plt.axis('off')
    plt.savefig(filename)
    if show:
        plt.show()
    else:
        plt.close('all')

训练结果

#运行
if __name__ == '__main__':
    build_and_train_models()
step=1000:

训练1000steps

step=15000:

训练15000steps

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

盼小辉丶

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值