A Detailed Walkthrough of the CGAN Paper and Code

Source | Videojj AI Cloud (极链AI云), a cost-effective shared GPU compute platform. New users can claim a 198 RMB welcome package worth 100+ hours of free GPU time: https://cloud.videojj.com/

 

In this post we walk through the details of the CGAN (Conditional GAN) paper, "Conditional Generative Adversarial Nets" (Mirza & Osindero, 2014). A related discussion board is available at: https://cloud.videojj.com/bbs/category/2/general-discussion

1. GAN Recap

To set the stage for CGAN, we first recap the relevant details of GAN. A GAN consists of two networks, a generator $G$ and a discriminator $D$. The generator maps randomly sampled Gaussian noise to an image (a "fake"), while the discriminator outputs the probability that its input image is real data rather than a fake produced by the generator.

Here we denote the data by $x$, the generator's output distribution by $p_g$, and the prior over the input noise by $p_z(z)$. The generator's mapping of a noise sample $z$ is written $G(z;\theta_g)$, and the discriminator's output on an input $x$ is $D(x;\theta_d)$.

The goal of a GAN is to conjure convincing images out of noise: the generator $G$ must produce "fakes" that fool the discriminator $D$. The optimum is reached when the discriminator assigns a score of 0.5 to the generator's samples, i.e., it can no longer tell whether an input is real or fake. The GAN objective is:
$$
\min_G \max_D V(D,G)=\mathbb{E}_{x\sim p_{data}(x)}[\log D(x)]+\mathbb{E}_{z\sim p_z(z)}[\log(1-D(G(z)))]\tag{1}
$$
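
To see why 0.5 is the optimum, recall the standard result from the original GAN paper: for a fixed generator, the value function in (1) is maximized pointwise by
$$
D^{*}(x)=\frac{p_{data}(x)}{p_{data}(x)+p_{g}(x)}
$$
so when the generator matches the data distribution exactly ($p_g = p_{data}$), the best the discriminator can achieve is $D^{*}(x)=\tfrac{1}{2}$ for every input.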

 

2. The CGAN Architecture in Detail

With GAN reviewed, we now turn to CGAN. The original GAN generator produces images from random noise alone, so there is no way to control (or even know) which class the generated image belongs to, and the discriminator likewise only receives an image and decides whether it came from the generator. CGAN's main contribution is to add extra information $y$ to the inputs of both the generator and the discriminator. The extra information $y$ can be anything, for example a class label. With this change, a GAN can be trained on images together with their labels, and at test time it can generate an image for a given label.

In the CGAN paper, both networks are MLPs (fully connected networks). In the generator, a noise sample $z$ drawn from the prior $p_z(z)$ and the extra information $y$ are combined in a joint hidden representation (in the paper, by concatenating them and feeding the result through fully connected layers). Likewise, in the discriminator the input image $x$ and the extra information $y$ are combined and fed to the hidden layers. The CGAN architecture is shown below:

https://img-blog.csdnimg.cn/20191108102239799.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9kYWlwdXdlaWFpLmJsb2cuY3Nkbi5uZXQ=,size_16,color_FFFFFF,t_70#pic_center

 

The CGAN objective can then be written as:
$$
\min_G \max_D V(D,G)=\mathbb{E}_{x\sim p_{data}(x)}[\log D(x|y)]+\mathbb{E}_{z\sim p_z(z)}[\log(1-D(G(z|y)))]\tag{2}
$$
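
As a minimal sketch of how this conditioning can be wired up, the snippet below builds an illustrative Keras MLP that concatenates a one-hot label with the noise (in the generator) and with the flattened image (in the discriminator), as described in the paper. The layer sizes and variable names here are assumptions for illustration, not the repository's code.

from keras import Input, Model
from keras.layers import Dense, LeakyReLU, concatenate

noise_dim, label_dim, img_dim = 100, 10, 784   # assumed sizes (MNIST flattened to 784 pixels)

# Generator G(z | y): concatenate the noise vector with a one-hot label
z = Input(shape=(noise_dim,))
y_g = Input(shape=(label_dim,))
h = LeakyReLU(0.2)(Dense(256)(concatenate([z, y_g])))
fake_img = Dense(img_dim, activation='tanh')(h)
generator = Model([z, y_g], fake_img)

# Discriminator D(x | y): concatenate the flattened image with the same one-hot label
x = Input(shape=(img_dim,))
y_d = Input(shape=(label_dim,))
h = LeakyReLU(0.2)(Dense(256)(concatenate([x, y_d])))
validity = Dense(1, activation='sigmoid')(h)
discriminator = Model([x, y_d], validity)

The repository code in the next section takes a slightly different route: it embeds the integer label with an Embedding layer and fuses it with the noise (or image) by element-wise multiplication, a common variant of the same conditioning idea.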

Below are the handwritten-digit samples reported in the CGAN paper. Each row corresponds to one label; for example, the first row contains images generated for label 0.

https://img-blog.csdnimg.cn/20191108102955170.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9kYWlwdXdlaWFpLmJsb2cuY3Nkbi5uZXQ=,size_16,color_FFFFFF,t_70#pic_center

 

 

3. CGAN on MNIST: Code Walkthrough

Next we walk through the Keras code for generating handwritten digits with a CGAN (GitHub repository: CGAN-mnist). First, the code that builds the CGAN networks:

import os
import datetime
import numpy as np
import matplotlib.pyplot as plt

from scipy.stats import truncnorm


from keras import Input
from keras import Model
from keras import Sequential

from keras.layers import Dense
from keras.layers import Activation
from keras.layers import Reshape
from keras.layers import Conv2DTranspose
from keras.layers import BatchNormalization
from keras.layers import Conv2D
from keras.layers import LeakyReLU
from keras.layers import Dropout
from keras.layers import Flatten
from keras.layers.merge import multiply
from keras.layers.merge import concatenate
from keras.layers.merge import add
from keras.layers import Embedding
from keras.utils import to_categorical
from keras.optimizers import Adam
from keras.utils.generic_utils import Progbar
from copy import deepcopy
from keras.datasets import mnist

def make_trainable(net, val):
    """ Freeze or unfreeze layers
    """
    net.trainable = val
    for l in net.layers: l.trainable = val

class CGAN(object):

    def __init__(self,config,weight_path=None):
        """
        Initialize the CGAN.
        :param config: configuration object holding the hyper-parameters
        :param weight_path: path to a pretrained weight file, default None
        """
        self.config = config
        self.build_cgan_model()

        if weight_path is not None:
            self.cgan.load_weights(weight_path,by_name=True)

    def build_cgan_model(self):
        """
        Build the CGAN: the generator, the discriminator, and the combined model.
        :return:
        """
        # Model inputs: noise, conditional label, and real image
        self.generator_noise_input = Input(shape=(self.config.generator_noise_input_dim,))
        self.condational_label_input = Input(shape=(1,), dtype='int32')
        self.discriminator_image_input = Input(shape=self.config.discriminator_image_input_dim)

        # Optimizer shared by both models
        self.optimizer = Adam(lr=2e-4, beta_1=0.5)

        # Build and compile the discriminator, then build the generator
        self.discriminator_model = self.build_discriminator_model()
        self.discriminator_model.compile(optimizer=self.optimizer, loss=['binary_crossentropy'],metrics=['accuracy'])
        self.generator_model = self.build_generator()

        # Build the combined CGAN model: freeze the discriminator and stack it on the generator
        self.discriminator_model.trainable = False
        self.cgan_input = [self.generator_noise_input,self.condational_label_input]
        generator_output = self.generator_model(self.cgan_input)
        cgan_output = self.discriminator_model([generator_output,self.condational_label_input])
        self.cgan = Model(self.cgan_input,cgan_output)

        # Compile the combined model (the discriminator was already compiled above)
        self.cgan.compile(optimizer=self.optimizer,loss=['binary_crossentropy'])

    def build_discriminator_model(self):
        """
        Build the discriminator model D(x | y).
        :return: the discriminator as a Keras Model taking [image, label] as input
        """
        model = Sequential()

        model.add(Dense(512, input_dim=np.prod(self.config.discriminator_image_input_dim)))
        model.add(LeakyReLU(alpha=self.config.LeakyReLU_alpha))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=self.config.LeakyReLU_alpha))
        model.add(Dropout(self.config.LeakyReLU_alpha))  # note: the config's LeakyReLU_alpha is reused here as the dropout rate
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=self.config.LeakyReLU_alpha))
        model.add(Dropout(self.config.LeakyReLU_alpha))
        model.add(Dense(1, activation='sigmoid'))
        model.summary()

        img = Input(shape=self.config.discriminator_image_input_dim)
        label = Input(shape=(1,), dtype='int32')

        # Condition on the label: embed it to the flattened image size and fuse it with the image by element-wise multiplication
        label_embedding = Flatten()(Embedding(self.config.condational_label_num,
                                              np.prod(self.config.discriminator_image_input_dim))(label))
        flat_img = Flatten()(img)
        model_input = multiply([flat_img, label_embedding])
        validity = model(model_input)

        return Model([img, label], validity)


    def build_generator(self):
        """
        Build the generator model G(z | y).
        :return: the generator as a Keras Model taking [noise, label] as input
        """
        model = Sequential()

        model.add(Dense(256, input_dim=self.config.generator_noise_input_dim))
        model.add(LeakyReLU(alpha=self.config.LeakyReLU_alpha))
        model.add(BatchNormalization(momentum=self.config.batchnormalization_momentum))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=self.config.LeakyReLU_alpha))
        model.add(BatchNormalization(momentum=self.config.batchnormalization_momentum))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=self.config.LeakyReLU_alpha))
        model.add(BatchNormalization(momentum=self.config.batchnormalization_momentum))
        model.add(Dense(np.prod(self.config.discriminator_image_input_dim), activation='tanh'))
        model.add(Reshape(self.config.discriminator_image_input_dim))

        model.summary()

        noise = Input(shape=(self.config.generator_noise_input_dim,))
        label = Input(shape=(1,), dtype='int32')
        # Condition on the label: embed it to the noise dimension and fuse it with the noise by element-wise multiplication
        label_embedding = Flatten()(Embedding(self.config.condational_label_num, self.config.generator_noise_input_dim)(label))

        model_input = multiply([noise, label_embedding])
        img = model(model_input)

        return Model([noise, label], img)

    def train(self, train_datagen, epoch, k, batch_size=256):
        """
        Train the CGAN.
        :param train_datagen: training data generator
        :param epoch: number of training epochs
        :param batch_size: mini-batch size
        :param k: number of discriminator updates per generator update
        :return:
        """
        time =datetime.datetime.now().strftime("%Y%m%d%H%M%S")
        model_path = os.path.join(self.config.model_dir,time)
        if not os.path.exists(model_path):
            os.mkdir(model_path)

        train_result_path = os.path.join(self.config.train_result_dir,time)
        if not os.path.exists(train_result_path):
            os.mkdir(train_result_path)

        for ep in np.arange(1, epoch+1).astype(np.int32):
            cgan_losses = []
            d_losses = []
            # Progress bar over the batches in this epoch
            length = train_datagen.batch_num
            progbar = Progbar(length)
            print('Epoch {}/{}'.format(ep, epoch))
            iter = 0
            while True:
                # the data generator advances its epoch counter after a full pass; stop this epoch's while-loop then
                #print("iter:{},{}".format(iter,train_datagen.get_epoch() != ep))
                if train_datagen.epoch != ep:
                    break

                # Fetch a batch of real images and build the 'real' targets for the discriminator
                batch_real_images, batch_real_labels = train_datagen.next_batch()
                batch_real_num_labels = np.ones((batch_size, 1))
                #batch_real_num_labels = truncnorm.rvs(0.7, 1.2, size=(batch_size, 1))
                # Sample random noise that the generator will turn into fake images
                batch_noises = np.random.normal(0, 1, size = (batch_size, self.config.generator_noise_input_dim))
                d_loss = []
                for i in np.arange(k):
                    # Build the 'fake' targets; the fake images reuse the class labels of the real batch
                    batch_fake_num_labels = np.zeros((batch_size,1))
                    #batch_fake_num_labels = truncnorm.rvs(0.0, 0.3, size=(batch_size, 1))
                    batch_fake_labels = deepcopy(batch_real_labels)
                    batch_fake_images = self.generator_model.predict([batch_noises,batch_fake_labels])

                    # Train the discriminator on the real batch and the fake batch separately
                    real_d_loss = self.discriminator_model.train_on_batch([batch_real_images,batch_real_labels],
                                                                                      batch_real_num_labels)
                    fake_d_loss = self.discriminator_model.train_on_batch([batch_fake_images, batch_fake_labels],
                                                                                      batch_fake_num_labels)
                    d_loss.append(list(0.5*np.add(real_d_loss,fake_d_loss)))
                #print(d_loss)
                d_losses.append(list(np.average(d_loss,0)))
                #print(d_losses)

                # Train the generator (through the frozen discriminator) with random class labels
                #batch_num_labels = truncnorm.rvs(0.7, 1.2, size=(batch_size, 1))
                batch_num_labels = np.ones((batch_size,1))
                batch_labels = np.random.randint(0, 10, batch_size).reshape(-1, 1)
                cgan_loss = self.cgan.train_on_batch([batch_noises,batch_labels], batch_num_labels)
                cgan_losses.append(cgan_loss)

                # Update the progress bar
                progbar.update(iter, [('dcgan_loss', cgan_losses[iter]),
                                      ('discriminator_loss',d_losses[iter][0]),
                                      ('acc',d_losses[iter][1])])
                #print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (ep, d_losses[ep][0], 100 * d_losses[ep][1],cgan_loss))
                iter += 1
            if ep % self.config.save_epoch_interval == 0:
                model_cgan = "Epoch{}dcgan_loss{}discriminator_loss{}acc{}.h5".format(ep, np.average(cgan_losses),
                                                                                      np.average(d_losses,0)[0],np.average(d_losses,0)[1])
                self.cgan.save(os.path.join(model_path, model_cgan))
                save_dir = os.path.join(train_result_path, str("Epoch{}".format(ep)))
                if not os.path.exists(save_dir):
                    os.mkdir(save_dir)
                self.save_image(int(ep), save_dir)
            '''
            if int(ep) in self.config.generate_image_interval:
                save_dir = os.path.join(train_result_path,str("Epoch{}".format(ep)))
                if not os.path.exists(save_dir):
                    os.mkdir(save_dir)
                self.save_image(ep,save_dir)
            '''
        # Plot the per-batch loss curves of the final epoch
        plt.plot(np.arange(len(cgan_losses)), cgan_losses, 'b-', label='cgan-loss')
        plt.plot(np.arange(len(d_losses)), np.array(d_losses)[:, 0], 'r-', label='d-loss')
        plt.grid(True)
        plt.legend(loc="best")
        plt.xlabel("Iteration")
        plt.ylabel("Loss")
        plt.savefig(os.path.join(train_result_path,"loss.png"))

    def save_image(self, epoch,save_path):
        """
        Save a 10x10 grid of generated digits (one row per label).
        :param epoch: current epoch number (used in the file name)
        :param save_path: directory where the image grid is saved
        :return:
        """
        rows, cols = 10, 10

        fig, axs = plt.subplots(rows, cols)
        for i in range(rows):
            label = np.array([i]*rows).astype(np.int32).reshape(-1,1)
            noise = np.random.normal(0, 1, (cols, self.config.generator_noise_input_dim))
            images = self.generator_model.predict([noise,label])
            images = 127.5*images+127.5
            cnt = 0
            for j in range(cols):
                #img_path = os.path.join(save_path, str(cnt) + ".png")
                #cv2.imwrite(img_path, images[cnt])
                #axs[i, j].imshow(image.astype(np.int32)[:,:,0])
                axs[i, j].imshow(images[cnt,:, :, 0].astype(np.int32), cmap='gray')
                axs[i, j].axis('off')
                cnt += 1
        fig.savefig(os.path.join(save_path, "mnist-{}.png".format(epoch)), dpi=600)
        plt.close()

    def generate_image(self,label):
        """
        Generate a single fake image for the given class label.
        :param label: integer class label (0-9 for MNIST)
        """
        noise = truncnorm.rvs(-1, 1, size=(1, self.config.generator_noise_input_dim))
        label = np.array([label]).T
        image = self.generator_model.predict([noise,label])[0]
        image = 127.5*(image+1)
        return image
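
The MnistConfig class used by this code is not shown in the post. A minimal sketch that supplies the attributes referenced above could look like the following; the attribute names are taken from the code, while the concrete values are assumptions:

import os

class MnistConfig(object):
    """Hypothetical configuration class: attribute names match the code above, values are assumptions."""
    def __init__(self):
        # input dimensions (MNIST images and a 100-d noise vector are assumed)
        self.generator_noise_input_dim = 100
        self.discriminator_image_input_dim = (28, 28, 1)
        self.condational_label_num = 10                  # number of digit classes
        # hyper-parameters (assumed values)
        self.LeakyReLU_alpha = 0.2
        self.batchnormalization_momentum = 0.8
        self.save_epoch_interval = 10
        # output directories for checkpoints and generated images
        self.model_dir = os.path.abspath("./model")
        self.train_result_dir = os.path.abspath("./train_result")
        for d in (self.model_dir, self.train_result_dir):
            if not os.path.exists(d):
                os.makedirs(d)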

To train the model we also need a dataset iterator that serves mini-batches of handwritten-digit images. The iterator class is as follows:

# -*- coding: utf-8 -*-
# @Time    : 2019/10/8 17:29
# @Author  : Dai PuWei
# @File    : MnistGenerator.py
# @Software: PyCharm

import math
import numpy as np
from keras.datasets import mnist

class MnistGenerator(object):

    def __init__(self,batch_size):
        """
        Initialize the MNIST mini-batch generator.
        :param batch_size: mini-batch size
        """
        (x_train,y_train),(x_test,y_test) = mnist.load_data()
        #self.x = np.concatenate([x_train,x_test]).astype(np.float32)
        self.x = np.expand_dims((x_train.astype(np.float32)-127.5)/127.5,axis=-1)
        #self.y = to_categorical(np.concatenate([y_train,y_test]),num_classes=10)
        self.y = y_train.reshape(-1,1)
        #self.y = self.y[y == ]
        #print(np.shape(self.x))
        #print(np.shape(self.y))
        self.images_size = len(self.x)
        random_index = np.random.permutation(np.arange(self.images_size))
        self.x = self.x[random_index]
        self.y = self.y[random_index]

        self.epoch = 1                                  # current epoch counter (starts at 1)
        self.batch_size = int(batch_size)
        self.batch_num = math.ceil(self.images_size / self.batch_size)
        self.start = 0
        self.end = 0
        self.finish_flag = False                        # flag marking that the dataset has been fully traversed once

    def _next_batch(self):
        """
        :return: an infinite generator yielding (batch_images, batch_labels) tuples
        """
        while True:
            #batch_images = np.array([])
            #batch_labels = np.array([])
            if self.finish_flag:  # the dataset has been traversed once: reshuffle and start a new epoch
                random_index = np.random.permutation(np.arange(self.images_size))
                self.x = self.x[random_index]
                self.y = self.y[random_index]
                self.finish_flag = False
                self.epoch += 1
            self.end = int(np.min([self.images_size,self.start+self.batch_size]))
            batch_images = self.x[self.start:self.end]
            batch_labels = self.y[self.start:self.end]
            batch_size = self.end - self.start
            if self.end == self.images_size:            # the dataset divides evenly into batches; mark the epoch as finished
                self.finish_flag = True
            if batch_size < self.batch_size:        # short final batch: reshuffle and pad with samples from the next epoch
                random_index = np.random.permutation(np.arange(self.images_size))
                self.x = self.x[random_index]
                self.y = self.y[random_index]
                batch_images = np.concatenate((batch_images, self.x[0:self.batch_size - batch_size]))
                batch_labels = np.concatenate((batch_labels, self.y[0:self.batch_size - batch_size]))
                self.start = self.batch_size - batch_size
                self.epoch += 1
            else:
                self.start = self.end
            yield batch_images,batch_labels

    def next_batch(self):
        datagen = self._next_batch()
        return datagen.__next__()
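
A quick sanity check of the iterator (a usage sketch, not part of the original repository):

if __name__ == '__main__':
    gen = MnistGenerator(batch_size=256)
    images, labels = gen.next_batch()
    print(images.shape, labels.shape)          # expected: (256, 28, 28, 1) (256, 1)
    print("batches per epoch:", gen.batch_num, "current epoch:", gen.epoch)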

The training script is as follows:

# -*- coding: utf-8 -*-
# @Time    : 2019/10/8 15:43
# @Author  : Dai PuWei
# @File    : train.py
# @Software: PyCharm

import os
import datetime

from CGAN.CGAN import CGAN
from Config.Config import MnistConfig
from DataGenerator.MnistGenerator import MnistGenerator

def run_main():
    """
    Main entry point: build the CGAN and train it on MNIST.
    """
    cfg =  MnistConfig()
    cgan = CGAN(cfg)
    batch_size = 512
    #train_datagen = Cifar10Generator(int(batch_size/2))
    train_datagen = MnistGenerator(batch_size)
    cgan.train(train_datagen,100000,1,batch_size)


if __name__ == '__main__':
    run_main()

Below are handwritten digits generated by the CGAN during training. Results after epoch 1:

https://img-blog.csdnimg.cn/20191108112828632.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9kYWlwdXdlaWFpLmJsb2cuY3Nkbi5uZXQ=,size_16,color_FFFFFF,t_70#pic_center

Results after epoch 10:

https://img-blog.csdnimg.cn/20191108112934197.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9kYWlwdXdlaWFpLmJsb2cuY3Nkbi5uZXQ=,size_16,color_FFFFFF,t_70#pic_center

Results after epoch 100:

https://img-blog.csdnimg.cn/20191108112958402.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9kYWlwdXdlaWFpLmJsb2cuY3Nkbi5uZXQ=,size_16,color_FFFFFF,t_70#pic_center

Results after epoch 1000:

https://img-blog.csdnimg.cn/20191108113037329.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9kYWlwdXdlaWFpLmJsb2cuY3Nkbi5uZXQ=,size_16,color_FFFFFF,t_70#pic_center

Finally, the test script for the CGAN:

# -*- coding: utf-8 -*-
# @Time    : 2019/11/8 13:11
# @Author  : DaiPuWei
# @Email   : daipuwei@qq.com
# @File    : test.py
# @Software: PyCharm


import os
from CGAN.CGAN import CGAN
from Config.Config import MnistConfig

def run_main():
    """
    Main entry point: load trained weights and save a grid of generated digits.
    """
    weight_path = os.path.abspath("./model/20191009134644/Epoch1378dcgan_loss1.5952800512313843discriminator_loss[0.49839333 0.7379193 ]acc[0.49839333 0.7379193 ].h5")
    result_path = os.path.abspath("./test_result")
    if not os.path.exists(result_path):
        os.mkdir(result_path)
    cfg =  MnistConfig()
    cgan = CGAN(cfg,weight_path)
    cgan.save_image(0,result_path)


if __name__ == '__main__':
    run_main()
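
Besides save_image, the CGAN class also exposes generate_image for producing a single digit of a chosen class. A usage sketch (the weight path is a placeholder for a real checkpoint file):

import matplotlib.pyplot as plt

from CGAN.CGAN import CGAN
from Config.Config import MnistConfig

weight_path = "./model/20191009134644/..."     # placeholder: path to a trained weight file
cgan = CGAN(MnistConfig(), weight_path)
image = cgan.generate_image(label=3)           # returns a (28, 28, 1) array scaled to [0, 255]
plt.imshow(image[:, :, 0], cmap='gray')
plt.axis('off')
plt.show()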

 

