GAN与WassersteinGAN代码keras分析

本文深入探讨了GAN(生成对抗网络)和Wasserstein GAN的Keras实现,通过分析`train_GAN.py`和`train_WGAN.py`代码,详细展示了如何在mnist数据集上训练这两种模型。同时,`models_WGAN.py`提供了Wasserstein GAN的模型定义。
摘要由CSDN通过智能技术生成

https://github.com/tdeboissiere/DeepLearningImplementations
使用mnist数据集。
train_GAN.py:


        gen_loss = 100
        disc_loss = 100

        # Start training
        print("Start training")
        for e in range(nb_epoch):
        # 每一回合
        # Initialize progbar and batch counter
            progbar = generic_utils.Progbar(epoch_size)
            batch_counter = 1
            start = time.time()

            for X_real_batch in data_utils.gen_batch(X_real_train, batch_size):
# 从X_real_train随机抽取一个batch,循环次数,通过下面
# 达到batch_counter >= n_batch_per_epoch,break控
# 制
                # Create a batch to feed the discriminator model
                X_disc, y_disc = data_utils.get_disc_batch(X_real_batch,
                                                           generator_model,
                                                           batch_counter,
                                                           batch_size,
                                                           noise_dim,
                                                           noise_scale=noise_scale,
                                                           label_smoothing=label_smoothing,
                                                           label_flipping=label_flipping)
# get_disc_batch为判别器生成一个batch数据,X_disc, 
# y_disc分别为数据和标签
# data_utils.get_disc_batch 与WGAN中不同,这里是
# batch_counter奇偶交替,数据为真实数据和生成的数据交替


                # Update the discriminator
                disc_loss = discriminator_model.train_on_batch(X_disc, y_disc)
                # 判别器更新一次             
                # Create a batch to feed the generator model
                X_gen, y_gen = data_utils.get_gen_batch(batch_size, noise_dim, noise_scale=noise_scale)
                # 采样一个batch的噪声给生成器
                # Freeze the discriminator
                discriminator_model.trainable = False
                gen_loss = DCGAN_model.train_on_batch(X_gen, y_gen)
                # 更新一次生成器
                # Unfreeze the discriminator
                discriminator_model.trainable = True

                batch_counter += 1
                progbar.add(batch_size, values=[("D logloss", disc_loss),
                                                ("G logloss", gen_loss)])

                # Save images for visualization
                if batch_counter % 100 == 0:
                    data_utils
  • 2
    点赞
  • 20
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值