基于keras的CGAN手把手代码教学+注释

基于keras的CGAN手把手代码教学+注释

代码如下


# CGAN_0412 by plus_left
# 2021/4/12

import os

import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Dense, LeakyReLU, BatchNormalization, Reshape
from tensorflow.keras.layers import Input, Embedding, Flatten, multiply, Dropout
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam


class CGAN():
    """Conditional GAN (CGAN) for MNIST digits built with tf.keras.

    The generator maps ``[noise, label]`` to a 28x28x1 image with values in
    [-1, 1]; the discriminator maps ``[image, label]`` to the probability that
    the image is real. ``combined`` chains the generator into a frozen
    discriminator so that training it updates only the generator.
    """

    def __init__(self):
        # Input image dimensions (MNIST: 28x28, single channel).
        self.img_rows = 28
        self.img_cols = 28
        self.img_channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.img_channels)

        self.num_classes = 10  # number of label classes
        self.latent_dim = 100  # length of the noise vector fed to the generator

        # Adam(learning_rate=0.0002, beta_1=0.5). Use one instance per compiled
        # model: newer tf.keras optimizers (TF >= 2.11 / Keras 3) remember the
        # variables they first updated and raise if the same instance is then
        # applied to a different model's variables. Separate instances also keep
        # each model's Adam step/moment state independent.
        d_optimizer = Adam(0.0002, 0.5)
        g_optimizer = Adam(0.0002, 0.5)

        self.generator = self.build_generator()
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss=['binary_crossentropy'],
                                   optimizer=d_optimizer,
                                   metrics=['accuracy'])

        # Freeze the discriminator for the combined model. This is the usual
        # Keras GAN pattern: it relies on `trainable` being captured at compile
        # time, so the discriminator compiled above still trains on its own.
        # NOTE(review): confirm this still holds on the Keras version in use.
        self.discriminator.trainable = False

        # Same shapes/dtypes as the generator's own inputs.
        noise = Input(shape=(self.latent_dim,))
        label = Input(shape=(1,), dtype='int32')

        img = self.generator([noise, label])
        valid = self.discriminator([img, label])

        # Generator followed by the frozen discriminator; trained so that
        # generated images are classified as valid.
        self.combined = Model([noise, label], valid)
        self.combined.compile(loss=['binary_crossentropy'],
                              optimizer=g_optimizer)

    def build_generator(self):
        """Build the conditional generator.

        Returns:
            A Model with inputs ``[noise, label]``: noise of shape
            (None, latent_dim) and an int32 label of shape (None, 1).
            Its output is an image batch of shape (None, *img_shape) whose
            values lie in [-1, 1] because of the final tanh.
        """
        model = Sequential()

        model.add(Dense(256, input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(int(np.prod(self.img_shape)), activation='tanh'))
        model.add(Reshape(self.img_shape))

        model.summary()  # print the layer / parameter table

        noise = Input(shape=(self.latent_dim,))
        label = Input(shape=(1,), dtype='int32')

        # Embed each of the num_classes labels as a latent_dim vector.
        # Embedding outputs (None, 1, latent_dim); Flatten makes it (None, latent_dim).
        label_embedding = Flatten()(Embedding(self.num_classes, self.latent_dim)(label))

        # Condition the noise on the label by element-wise multiplication.
        model_input = multiply([noise, label_embedding])

        img = model(model_input)

        # Callers pass noise and label separately; they are merged inside.
        return Model([noise, label], img)

    def build_discriminator(self):
        """Build the conditional discriminator.

        Returns:
            A Model with inputs ``[img, label]``: images of shape
            (None, *img_shape) and an int32 label of shape (None, 1).
            Its output has shape (None, 1) and holds the sigmoid probability
            that the image is real.
        """
        flat_dim = int(np.prod(self.img_shape))  # 28*28*1 = 784

        model = Sequential()

        model.add(Dense(512, input_dim=flat_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.4))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.4))
        model.add(Dense(1, activation='sigmoid'))

        model.summary()

        img = Input(shape=self.img_shape)
        label = Input(shape=(1,), dtype='int32')

        # The label and the image have different shapes, so embed the label
        # into a flat_dim vector: (None, 1, flat_dim) -> Flatten -> (None, flat_dim).
        label_embedding = Flatten()(Embedding(self.num_classes, flat_dim)(label))
        flat_img = Flatten()(img)

        # Element-wise product of image and label embedding, shape (None, flat_dim).
        model_input = multiply([flat_img, label_embedding])

        validity = model(model_input)

        # Callers pass image and label separately; they are merged inside.
        return Model([img, label], validity)

    @staticmethod
    def _preprocess(images, labels):
        """Scale images to [-1, 1] and add a channel axis; reshape labels to a column.

        Args:
            images: array of shape (N, rows, cols) with pixel values in [0, 255].
            labels: array-like of N integer class labels.

        Returns:
            Tuple of a float32 array of shape (N, rows, cols, 1) and labels
            reshaped to (N, 1).
        """
        images = (images.astype(np.float32) - 127.5) / 127.5
        images = np.expand_dims(images, axis=3)
        labels = np.asarray(labels).reshape(-1, 1)
        return images, labels

    def train(self, epochs, batch_size = 128, sample_interval = 50):
        """Train the CGAN on the MNIST training set.

        Args:
            epochs: number of training iterations. Each iteration draws one
                random batch; it is not a full pass over the dataset.
            batch_size: number of images per batch.
            sample_interval: call ``sample_images`` every this many iterations.
        """
        (X_train, Y_train), (_, _) = mnist.load_data()
        # Images become (N, 28, 28, 1) in [-1, 1]; labels become (N, 1).
        X_train, Y_train = self._preprocess(X_train, Y_train)

        # Targets for real and generated images.
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):

            # ---- Train the discriminator ----
            # Random batch of real images with their labels (indices drawn with replacement).
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs, labels = X_train[idx], Y_train[idx]
            # imgs: (batch_size, 28, 28, 1); labels: (batch_size, 1)

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # verbose=0: avoid printing a progress bar on every iteration.
            gen_imgs = self.generator.predict([noise, labels], verbose=0)

            d_loss_real = self.discriminator.train_on_batch([imgs, labels], valid)
            d_loss_fake = self.discriminator.train_on_batch([gen_imgs, labels], fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)  # [loss, accuracy]

            # ---- Train the generator ----
            sampled_label = np.random.randint(0, self.num_classes, batch_size).reshape(-1, 1)

            # The discriminator is frozen inside `combined`, so only the generator updates.
            g_loss = self.combined.train_on_batch([noise, sampled_label], valid)

            print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

            # Periodically save a grid of generated samples.
            if epoch % sample_interval == 0:
                self.sample_images(epoch)

    def sample_images(self, epoch):
        """Save a 2x5 grid of generated digits to ``images/<epoch>.png``.

        The path is relative to the current working directory; the ``images``
        directory is created if it does not exist.
        """
        r, c = 2, 5  # 2 rows x 5 columns

        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        # One label per grid cell, cycling through the classes (0..9 with the defaults).
        sampled_labels = (np.arange(r * c) % self.num_classes).reshape(-1, 1)

        gen_imgs = self.generator.predict([noise, sampled_labels], verbose=0)

        # Rescale from the tanh range [-1, 1] to [0, 1] for display.
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        # axs.flat walks the grid row by row, matching the label order above.
        for cnt, ax in enumerate(axs.flat):
            ax.imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
            # Index down to a scalar; formatting a shape-(1,) array with %d is deprecated.
            ax.set_title("Digit: %d" % sampled_labels[cnt, 0])
            ax.axis('off')

        # savefig does not create missing directories.
        os.makedirs("images", exist_ok=True)
        fig.savefig("images/%d.png" % epoch)
        plt.close(fig)

if __name__ == '__main__':
    # Script entry point: 20000 single-batch iterations of 32 images,
    # saving a sample grid every 200 iterations.
    gan = CGAN()
    gan.train(epochs=20000, batch_size=32, sample_interval=200)

视频手把手教学

本代码的讲解视频已发布在 B 站，
可以结合视频同步学习本文内容。
B 站上本代码讲解视频的链接

  • 1
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

plus_left

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值