How to train a GAN model in Keras?

https://medium.com/dive-into-ml-ai/using-kerass-model-fit-to-train-a-gan-model-a0f02ed6d39e

 

In this article, I present three different methods for training a generative adversarial network (GAN), made up of a generator and a discriminator, using Keras (v2.4.3) on a TensorFlow (v2.2.0) backend. The methods differ in implementation complexity, execution speed and flexibility, and for each one I note my observations on these three aspects.

Method 1:

Alternately updating the discriminator and the generator batch by batch, inside nested for loops over epochs and training steps. Most references I found through an internet search use this method. The example code is written for a “data transformation model”, in which the generator maps an input x_in to an output that should match x_out. Make the necessary tweaks for the kind of model you need.

# Build generator, discriminator and combined models
discriminator = build_discriminator()
generator = build_generator()
combined = build_model(generator, discriminator)

# Compile the models
discriminator.compile(...)
discriminator.trainable = False  # do this after compiling the discriminator
                                 # and before compiling the combined model
combined.compile(...)

# Get the data generator
# (assumed to be an iterator yielding (x_in, x_out) batches)
train_generator = DataGenerator(mode="train", ...)

# Start training
for epoch in range(total_epochs):
    for step in range(train_steps):
        # Obtain a batch of data
        x_in, x_out = next(train_generator)

        # Prepare the labels (will depend on the GAN type)
        real_label, fake_label = ....

        # Train the discriminator on real (ground-truth) images
        d_loss_real = discriminator.train_on_batch(x_out, real_label)

        # Get a batch of generated data from the generator
        y_out = generator.predict(x_in)

        # Train the discriminator on generated images
        d_loss_gen = discriminator.train_on_batch(y_out, fake_label)

        # Combined model update (discriminator frozen)
        # mse_loss: loss between ground truth and generated data
        # g_loss: generator's adversarial loss from the discriminator
        # For a two-output model, train_on_batch returns
        # [total_loss, loss_output_1, loss_output_2]
        total_loss, mse_loss, g_loss = combined.train_on_batch(
            x_in, [x_out, real_label])
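The article does not show build_generator(), build_discriminator() or build_model(). The sketch below is one possible shape for the combined model in this data-transformation setup, assuming tiny dense stand-in networks (the three builders and the sizes IN_DIM and OUT_DIM are hypothetical). The combined model maps x_in to two outputs: the generated data and the discriminator's score for it. This is why train_on_batch in the loop receives the two targets [x_out, real_label] and is compiled with one loss per output.

```python
import numpy as np
from tensorflow import keras

IN_DIM, OUT_DIM = 8, 4  # hypothetical toy sizes


def build_generator():
    # Hypothetical stand-in: maps an input sample to a transformed output
    return keras.Sequential([keras.Input(shape=(IN_DIM,)),
                             keras.layers.Dense(16, activation="relu"),
                             keras.layers.Dense(OUT_DIM)])


def build_discriminator():
    # Hypothetical stand-in: scores an output sample as real or generated
    return keras.Sequential([keras.Input(shape=(OUT_DIM,)),
                             keras.layers.Dense(1, activation="sigmoid")])


def build_model(generator, discriminator):
    """Combined model: x_in -> [generated output, discriminator score]."""
    x_in = keras.Input(shape=generator.inputs[0].shape[1:])
    y_out = generator(x_in)
    validity = discriminator(y_out)
    return keras.Model(x_in, [y_out, validity])


generator = build_generator()
discriminator = build_discriminator()
combined = build_model(generator, discriminator)
# One loss per output: reconstruction (MSE) and adversarial (binary cross-entropy)
combined.compile(optimizer="adam", loss=["mse", "binary_crossentropy"])

# Forward pass on random toy inputs to check the two output shapes
x_in = np.random.default_rng(0).normal(size=(5, IN_DIM)).astype("float32")
y_pred, validity = combined.predict(x_in, verbose=0)
print(y_pred.shape, validity.shape)
```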

Note

Depending on the GAN type, the discriminator can also be trained on a single combined batch of “real and generated” data.
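A minimal NumPy sketch of assembling such a combined batch follows. It assumes that real and generated samples have the same shape. The function name make_discriminator_batch and the default label values (1 for real, 0 for generated) are illustrative assumptions, since the label convention depends on the GAN type.

```python
import numpy as np


def make_discriminator_batch(real, fake, real_value=1.0, fake_value=0.0):
    """Stack real and generated samples into one discriminator batch.

    Returns (x, y), where y holds real_value for the real rows and
    fake_value for the generated rows (the label values are an assumption).
    """
    real = np.asarray(real, dtype="float32")
    fake = np.asarray(fake, dtype="float32")
    if real.shape[1:] != fake.shape[1:]:
        raise ValueError("real and generated samples must have the same shape")
    x = np.concatenate([real, fake], axis=0)
    y = np.concatenate([np.full((len(real), 1), real_value, dtype="float32"),
                        np.full((len(fake), 1), fake_value, dtype="float32")],
                       axis=0)
    return x, y


# Toy arrays standing in for x_out and generator.predict(x_in)
real_batch = np.ones((4, 2))
fake_batch = np.zeros((3, 2))
x_d, y_d = make_discriminator_batch(real_batch, fake_batch)
print(x_d.shape, y_d.ravel())
```

The resulting (x_d, y_d) pair can then go to a single discriminator.train_on_batch() call, replacing the two separate calls in the loop.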

Observations

  1. Implementation: simple and straightforward
  2. Speed: slowest
  3. Flexibility: highly flexible, as the model characteristics (trainability of layers, loss_weights, optimizer parameters) can be adjusted before each training step

Method 2:

Alternately updating the discriminator or the generator epoch by epoch, inside a for loop over the total epochs. In each epoch, only one of the generator or the discriminator is trainable. The example code is written for “a general GAN model”. Make the necessary tweaks for the kind of model you need.

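from tensorflow import keras

# Define the data generator
class DataGenerator(keras.utils.Sequence):
    def __init__(self, mode, *args, **kwargs):
        ...
        self.epoch = 0  # updated from the training loop before each fit() call

    def __len__(self):
        ...  # number of batches per epoch

    def __getitem__(self, index):
        x = ...  # load batch number `index`
        # Prepare the labels
        # Will depend on the GAN type,
        # and on whether the generator or the discriminator is trainable
        if self.epoch % 2 == 0:
            labels = ....
        else:
            labels = ....
        return x, labels

# Build generator, discriminator and combined models
discriminator = build_discriminator()
generator = build_generator()
combined = build_model(generator, discriminator)

# Get the data generator
train_generator = DataGenerator(mode="train", ...)

# Start training
for epoch in range(2 * total_epochs):
    # Set the trainability: even epochs update the discriminator,
    # odd epochs update the generator
    if epoch % 2 == 0:
        generator.trainable = False
        discriminator.trainable = True
    else:
        generator.trainable = True
        discriminator.trainable = False
    train_generator.epoch = epoch  # select the labels for this epoch

    # Recompile so that the new trainability takes effect
    combined.compile(...)

    # Train the model for exactly one epoch
    combined.fit(train_generator,
                 initial_epoch=epoch,
                 epochs=epoch + 1)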
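The trainability switching at the core of this method can be checked on a toy combined model. This is a minimal sketch, assuming tiny dense stand-ins for the generator and discriminator and hypothetical labels (0 while the discriminator trains, 1 while the generator trains). It only exercises the freeze, recompile and one-epoch fit() mechanism: real samples never reach the discriminator here, so it is not a complete GAN. After each epoch it records whether each sub-model's weights changed.

```python
import numpy as np
from tensorflow import keras

LATENT_DIM, DATA_DIM = 4, 3  # hypothetical toy sizes

# Tiny stand-ins for build_generator() / build_discriminator() (hypothetical)
generator = keras.Sequential([keras.Input(shape=(LATENT_DIM,)),
                              keras.layers.Dense(DATA_DIM)])
discriminator = keras.Sequential([keras.Input(shape=(DATA_DIM,)),
                                  keras.layers.Dense(1, activation="sigmoid")])
combined = keras.Sequential([generator, discriminator])

z = np.random.default_rng(0).normal(size=(64, LATENT_DIM)).astype("float32")


def snapshot(model):
    return [w.copy() for w in model.get_weights()]


def changed(before, after):
    return any(not np.allclose(b, a) for b, a in zip(before, after))


history = []  # (epoch, generator_changed, discriminator_changed)
total_epochs = 2
for epoch in range(2 * total_epochs):
    d_turn = epoch % 2 == 0
    generator.trainable = not d_turn
    discriminator.trainable = d_turn
    # Hypothetical labels: "fake" (0) while D trains, "real" (1) while G trains
    value = 0.0 if d_turn else 1.0
    labels = np.full((len(z), 1), value, dtype="float32")

    # Recompile so the new trainable flags are picked up by fit()
    combined.compile(optimizer=keras.optimizers.SGD(learning_rate=0.1),
                     loss="binary_crossentropy")
    g_before, d_before = snapshot(generator), snapshot(discriminator)
    combined.fit(z, labels, batch_size=16,
                 initial_epoch=epoch, epochs=epoch + 1, verbose=0)
    history.append((epoch, changed(g_before, snapshot(generator)),
                    changed(d_before, snapshot(discriminator))))

for row in history:
    print(row)
```

With the recompile in place, the recorded rows should alternate between updating only the discriminator (even epochs) and only the generator (odd epochs).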

Observations

  1. Implementation: a bit more involved than method 1, since the combined model needs to be recompiled after every epoch.
  2. Speed: faster than method 1, since model.fit() can apply its internal parallelization over batches
  3. Flexibility: additional flexibility through the use of callbacks
  4. Doubles the number of training epochs
  5. Starts a new TensorBoard profiler session at every call of model.fit(), which consumes further time
  6. Makes a forward pass over the whole combined model, even for the discriminator update
  7. In my project using this method, I faced a memory leak (for how to handle this issue, please refer to my article)
  8. As the trainability of the generator and the discriminator alternates only once per epoch, there is a risk of model collapse, with one model becoming too good relative to the other.

Method 3:

Customizing what happens in model.fit() by overriding the train_step() method of a Keras Model subclass, with TensorFlow's lower-level API (tf.GradientTape) inside it. TensorFlow's example implementation is provided in its guide “Customizing what happens in fit()”. However, the actual implementation needs a few more steps (refer to my article for the complete implementation). The example code is written for “a generator network synthesizing images from random latent vectors”. Make the necessary tweaks for the kind of model you need.

import tensorflow as tf
from tensorflow.keras import models as KM

class MyModel(KM.Model):
    def __init__(self, discriminator, generator, latent_dim):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, real_images):
        if isinstance(real_images, tuple):
            real_images = real_images[0]
        # Sample random points in the latent space
        batch_size = tf.shape(real_images)[0]
        random_latent_vectors = tf.random.normal(shape=(batch_size,
                                                        self.latent_dim))

        # Decode them to fake images
        generated_images = self.generator(random_latent_vectors)

        # Combine them with real images
        combined_images = tf.concat([generated_images, real_images],
                                    axis=0)

        # Assemble labels discriminating real from fake images
        # (generated images -> 1, real images -> 0)
        labels = tf.concat(
            [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))],
            axis=0
        )
        # Add random noise to the labels - important trick!
        labels += 0.05 * tf.random.uniform(tf.shape(labels))

        # Train the discriminator
        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images)
            d_loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(d_loss,
                              self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_weights)
        )

        # Sample random points in the latent space
        random_latent_vectors = tf.random.normal(shape=(batch_size,
                                                        self.latent_dim))

        # Assemble labels that say "all real images" (real is labelled 0)
        misleading_labels = tf.zeros((batch_size, 1))

        # Train the generator
        # Gradients are applied only to the generator's weights,
        # so the discriminator weights are not updated here
        with tf.GradientTape() as tape:
            predictions = self.discriminator(self.generator(
                random_latent_vectors))
            g_loss = self.loss_fn(misleading_labels, predictions)
        grads = tape.gradient(g_loss,
                              self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads,
                                             self.generator.trainable_weights))
        return {"d_loss": d_loss, "g_loss": g_loss}

Observations

  1. Implementation: the most involved; I would not recommend a beginner to implement GAN training this way, but once you are sufficiently comfortable with the Keras framework, I highly recommend this method.
  2. Speed: the fastest; much faster than method 2, and it needs only half the number of epochs
  3. Flexibility: additional flexibility through the use of callbacks
  4. The generator and discriminator can be treated as separate model objects; there is no need to define a combined model
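The train_step approach can be run end to end on a toy problem. This is a minimal sketch, assuming tiny dense generator/discriminator models on random 3-dimensional vectors instead of images (the class name ToyGAN and all sizes are hypothetical). It omits the label-noise trick for brevity. It also reports the losses through keras.metrics.Mean trackers, listed in the metrics property so Keras resets them every epoch, instead of returning the raw per-batch losses. A single model.fit() call then drives both updates, which is what lets this method use the standard callbacks.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras

LATENT_DIM, DATA_DIM = 2, 3  # hypothetical toy sizes


class ToyGAN(keras.Model):
    """Minimal train_step-based GAN on flat vectors (toy sketch)."""

    def __init__(self, generator, discriminator, latent_dim):
        super().__init__()
        self.generator = generator
        self.discriminator = discriminator
        self.latent_dim = latent_dim
        self.d_loss_tracker = keras.metrics.Mean(name="d_loss")
        self.g_loss_tracker = keras.metrics.Mean(name="g_loss")

    @property
    def metrics(self):
        # Listed here so Keras resets the trackers at the start of each epoch
        return [self.d_loss_tracker, self.g_loss_tracker]

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, real):
        if isinstance(real, (tuple, list)):
            real = real[0]
        batch_size = tf.shape(real)[0]

        # Discriminator update: generated -> 1, real -> 0
        z = tf.random.normal((batch_size, self.latent_dim))
        x = tf.concat([self.generator(z), real], axis=0)
        y = tf.concat([tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0)
        with tf.GradientTape() as tape:
            d_loss = self.loss_fn(y, self.discriminator(x))
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(zip(grads, self.discriminator.trainable_weights))

        # Generator update: push D towards "real" (0); only G's weights change
        z = tf.random.normal((batch_size, self.latent_dim))
        with tf.GradientTape() as tape:
            g_loss = self.loss_fn(tf.zeros((batch_size, 1)),
                                  self.discriminator(self.generator(z)))
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))

        self.d_loss_tracker.update_state(d_loss)
        self.g_loss_tracker.update_state(g_loss)
        return {"d_loss": self.d_loss_tracker.result(),
                "g_loss": self.g_loss_tracker.result()}


generator = keras.Sequential([keras.Input(shape=(LATENT_DIM,)),
                              keras.layers.Dense(DATA_DIM)])
discriminator = keras.Sequential([keras.Input(shape=(DATA_DIM,)),
                                  keras.layers.Dense(1, activation="sigmoid")])
gan = ToyGAN(generator, discriminator, LATENT_DIM)
gan.compile(d_optimizer=keras.optimizers.Adam(1e-3),
            g_optimizer=keras.optimizers.Adam(1e-3),
            loss_fn=keras.losses.BinaryCrossentropy())

real_data = np.random.default_rng(0).normal(size=(128, DATA_DIM)).astype("float32")
history = gan.fit(real_data, batch_size=32, epochs=2, verbose=0)
print(sorted(history.history))
```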

In this article, I have covered three methods for implementing GAN training in Keras. If you have any suggestions for a better training method, or are struggling with some step, please let me know.
