https://medium.com/dive-into-ml-ai/using-kerass-model-fit-to-train-a-gan-model-a0f02ed6d39e
In this article, I present three different methods for training a generative adversarial network (GAN) model using Keras (v2.4.3) on a TensorFlow (v2.2.0) backend. They vary in implementation complexity, speed of execution and flexibility, and I note my observations on each method for these three aspects.
Method 1:
Carry out a batch-wise update of the discriminator and the generator alternately, inside nested for loops over epochs and training steps. Most references found through an internet search seem to use this method. The example code is designed for a “data transformation model”. Make the necessary tweaks depending on the kind of model you require.
# Build generator, discriminator and combined models
discriminator = build_discriminator()
generator = build_generator()
combined = build_model(generator, discriminator)

# Compile the models
discriminator.compile(...)
# Freeze the discriminator inside the combined model;
# ensure to do this after compiling the discriminator model
discriminator.trainable = False
combined.compile(...)

# Get the data generators
train_generator = DataGenerator(mode="train", ...)

# Start training
for epoch in range(total_epochs):
    for step in range(train_steps):
        # Obtain a batch of data
        x_in, x_out = next(train_generator)

        # Prepare the labels
        # Will depend on the GAN type
        real_label, fake_label = ...

        # Train on original images
        d_loss_real = discriminator.train_on_batch(x_in, real_label)

        # Get batch of generated data from generator
        y_out = generator.predict(x_in)

        # Train on generated images
        d_loss_gen = discriminator.train_on_batch(y_out, fake_label)

        # Combined model update (discriminator frozen)
        # For a two-output model, train_on_batch returns the total loss
        # followed by one loss per output:
        # mse_loss (ground truth and generated data)
        # g_loss (generator's loss from discriminator)
        total_loss, mse_loss, g_loss = combined.train_on_batch(x_in, [x_out, real_label])
Note
Depending on the GAN type, the discriminator can also be trained on a single combined batch of “real and generated” data instead of two separate batches.
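The note above can be illustrated with a minimal NumPy sketch. It assumes a discriminator with one output per sample; the helper name make_discriminator_batch, the toy arrays and the label convention (1 for real, 0 for generated) are all assumptions for illustration and will differ with the GAN type.

```python
import numpy as np


def make_discriminator_batch(x_real, x_fake, rng=None):
    """Stack real and generated samples into one shuffled batch.

    Assumed convention: label 1 = real, label 0 = generated.
    """
    rng = np.random.default_rng() if rng is None else rng
    x = np.concatenate([x_real, x_fake], axis=0)
    y = np.concatenate(
        [
            np.ones((len(x_real), 1), dtype=np.float32),
            np.zeros((len(x_fake), 1), dtype=np.float32),
        ],
        axis=0,
    )
    # Shuffle samples and labels with the same permutation
    order = rng.permutation(len(x))
    return x[order], y[order]


# Toy arrays standing in for a real batch and a generator output
x_real = np.ones((4, 8, 8, 1), dtype=np.float32)
x_fake = np.zeros((4, 8, 8, 1), dtype=np.float32)
x_batch, y_batch = make_discriminator_batch(
    x_real, x_fake, rng=np.random.default_rng(0)
)
```

The resulting pair could then be passed to one discriminator.train_on_batch(x_batch, y_batch) call in place of the two separate calls in the loop.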
Observations
- Implementation: simple and straightforward
- Speed: slowest of the three methods
- Flexibility: highly flexible, as the model characteristics (trainability of layers, loss_weights, optimizer parameters) can be adjusted before each training step
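The training loop in this method pulls batches with next(train_generator), so the data source has to be an iterator that never runs out. The sketch below shows one way to get that behaviour with a plain Python generator over in-memory NumPy arrays; batch_generator, the toy data and the choice to drop the incomplete last batch are assumptions, not the article's DataGenerator.

```python
import numpy as np


def batch_generator(x, y, batch_size, seed=0):
    """Yield (x_in, x_out) batches forever, reshuffling after each pass.

    Only full batches are produced; leftover samples are skipped.
    """
    rng = np.random.default_rng(seed)
    n = len(x)
    while True:
        order = rng.permutation(n)
        for start in range(0, n - batch_size + 1, batch_size):
            idx = order[start:start + batch_size]
            yield x[idx], y[idx]


# Toy paired data: 10 inputs and their targets
inputs = np.arange(10, dtype=np.float32).reshape(10, 1)
targets = inputs * 2.0

batch_size = 4
train_generator = batch_generator(inputs, targets, batch_size)
train_steps = len(inputs) // batch_size  # full batches per epoch

x_in, x_out = next(train_generator)
```

Because the generator is infinite, the number of steps per epoch has to be fixed by the caller, which is the role train_steps plays in the nested loops.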
Method 2:
Carry out an epoch-wise update of the discriminator and the generator alternately, inside a for loop over the total epochs. In every epoch, only one of the generator and the discriminator is kept trainable. The example code is designed for “a general GAN model”. Make the necessary tweaks depending on the kind of model you require.
# Define the data generator
class DataGenerator():
    def __init__(self, mode, **kwargs):
        ...

    def __getitem__(self, index):
        # Prepare the labels
        # Will depend on the GAN type
        # Also on trainability of generator or discriminator
        # (epoch is the current epoch, which the generator must keep track of)
        if epoch % 2 == 0:
            labels = ...
        else:
            labels = ...

# Build generator, discriminator and combined models
discriminator = build_discriminator()
generator = build_generator()
combined = build_model(generator, discriminator)

# Get the data generators
train_generator = DataGenerator(mode="train", ...)

# Start training
for epoch in range(2 * total_epochs):
    # Set the trainability
    if epoch % 2 == 0:
        generator.trainable = False
        discriminator.trainable = True
    else:
        generator.trainable = True
        discriminator.trainable = False

    # Recompile the model so the new trainability takes effect
    combined.compile(...)

    # Train the model
    combined.fit(
        train_generator,
        initial_epoch=epoch,
        epochs=epoch + 1,  # Only one epoch run each compile
    )
Observations
- Implementation: slightly more involved than method 1; the combined model needs to be recompiled after every epoch.
- Speed: faster than method 1, as model.fit() can use its internal parallelization over batches.
- Flexibility: additional flexibility through the use of callbacks.
- Doubles the number of training epochs, since each epoch updates only one of the two networks.
- Starts a new TensorBoard profiler session at every call of model.fit(), which costs additional time.
- Makes a forward pass over the whole combined model, even for a discriminator update.
- In my project using this method, I faced a memory leak (for how I handled this issue, please refer to my article).
- As the trainability of the generator and the discriminator alternates only once per epoch, there is a possibility of model collapse, with one network becoming too good relative to the other.
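The doubling of epochs follows directly from the alternation rule. The pure-Python sketch below reproduces only the scheduling logic, without any Keras calls; trainable_flags is a hypothetical helper and total_epochs = 3 is an arbitrary value.

```python
def trainable_flags(epoch):
    """Return (generator_trainable, discriminator_trainable) for an epoch.

    Mirrors the alternation used in this method: even epochs update the
    discriminator, odd epochs update the generator.
    """
    train_discriminator = epoch % 2 == 0
    return (not train_discriminator, train_discriminator)


total_epochs = 3  # desired number of training epochs per network
schedule = [trainable_flags(epoch) for epoch in range(2 * total_epochs)]

# Count how many epochs each network actually gets trained for
generator_epochs = sum(g for g, _ in schedule)
discriminator_epochs = sum(d for _, d in schedule)
```

Each network is trainable in exactly total_epochs of the 2 * total_epochs passes, so the loop has to run twice as long to give both networks the same number of updates as a scheme that trains them within the same epoch.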
Method 3:
Customize what happens inside model.fit() by overriding train_step with TensorFlow’s lower-level API. TensorFlow’s example implementation is provided here. However, the actual implementation needs a few more steps (refer to my article for the complete implementation). The example code is designed for “a generator network synthesizing images from random latent vectors”. Make the necessary tweaks depending on the kind of model you require.
import tensorflow as tf
from tensorflow.keras import models as KM  # assumed meaning of the KM alias


class MyModel(KM.Model):
    # self.generator, self.discriminator, self.latent_dim, self.loss_fn,
    # self.d_optimizer and self.g_optimizer are expected to be set elsewhere
    def train_step(self, real_images):
        if isinstance(real_images, tuple):
            real_images = real_images[0]

        # Sample random points in the latent space
        batch_size = tf.shape(real_images)[0]
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))

        # Decode them to fake images
        generated_images = self.generator(random_latent_vectors)

        # Combine them with real images
        combined_images = tf.concat([generated_images, real_images], axis=0)

        # Assemble labels discriminating real from fake images
        labels = tf.concat(
            [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0
        )

        # Add random noise to the labels - important trick!
        labels += 0.05 * tf.random.uniform(tf.shape(labels))

        # Train the discriminator
        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images)
            d_loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_weights)
        )

        # Sample random points in the latent space
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))

        # Assemble labels that say "all real images"
        misleading_labels = tf.zeros((batch_size, 1))

        # Train the generator
        # Do not update the discriminator weights
        with tf.GradientTape() as tape:
            predictions = self.discriminator(self.generator(random_latent_vectors))
            g_loss = self.loss_fn(misleading_labels, predictions)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(
            zip(grads, self.generator.trainable_weights)
        )

        return {"d_loss": d_loss, "g_loss": g_loss}
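The label-noise step in the code above can be examined in isolation. The sketch below reproduces it with NumPy rather than TensorFlow (a substitution made only to keep the example light); the batch size and the seed are arbitrary assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size = 4

# Same convention as the code above: 1 = generated image, 0 = real image
labels = np.concatenate(
    [np.ones((batch_size, 1)), np.zeros((batch_size, 1))], axis=0
)

# Uniform noise in [0, 0.05) is added to every label
noisy_labels = labels + 0.05 * rng.uniform(size=labels.shape)

generated_targets = noisy_labels[:batch_size]
real_targets = noisy_labels[batch_size:]
```

The noise is one-sided: labels only move upward, so the targets for generated images land slightly above 1 and the targets for real images slightly above 0.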
Observations
- Implementation: the most involved; I would not recommend it to a beginner. Once you are sufficiently comfortable with the Keras framework, I highly recommend using this method.
- Speed: fastest; much faster than method 2, and it needs only half the number of epochs.
- Flexibility: additional flexibility through the use of callbacks.
- The generator and the discriminator can be treated as separate model objects; there is no need to define a combined model.
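The extra steps needed around a custom train_step mostly amount to storing the two networks and giving each its own optimizer. The sketch below is one possible wiring, loosely following the pattern in the Keras guide on customizing fit(); the class name GAN, the tiny dense networks, the 4-dimensional vectors standing in for images and all hyperparameters are assumptions, and the label-noise step is omitted for brevity.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras


class GAN(keras.Model):
    def __init__(self, discriminator, generator, latent_dim):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        # One optimizer per network; no combined model is needed
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, real_images):
        if isinstance(real_images, tuple):
            real_images = real_images[0]
        batch_size = tf.shape(real_images)[0]

        # Discriminator update (1 = generated, 0 = real)
        noise = tf.random.normal(shape=(batch_size, self.latent_dim))
        images = tf.concat([self.generator(noise), real_images], axis=0)
        labels = tf.concat(
            [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0
        )
        with tf.GradientTape() as tape:
            d_loss = self.loss_fn(labels, self.discriminator(images))
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_weights)
        )

        # Generator update: only generator weights receive gradients
        noise = tf.random.normal(shape=(batch_size, self.latent_dim))
        with tf.GradientTape() as tape:
            predictions = self.discriminator(self.generator(noise))
            g_loss = self.loss_fn(tf.zeros((batch_size, 1)), predictions)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(
            zip(grads, self.generator.trainable_weights)
        )
        return {"d_loss": d_loss, "g_loss": g_loss}


# Tiny stand-in networks (the discriminator outputs logits)
latent_dim = 2
generator = keras.Sequential(
    [keras.Input(shape=(latent_dim,)), keras.layers.Dense(4)]
)
discriminator = keras.Sequential(
    [keras.Input(shape=(4,)), keras.layers.Dense(1)]
)

gan = GAN(discriminator, generator, latent_dim)
gan.compile(
    d_optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    g_optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),
)

# Random data standing in for real images
real_data = np.random.normal(size=(16, 4)).astype("float32")
dataset = tf.data.Dataset.from_tensor_slices(real_data).batch(8)
history = gan.fit(dataset, epochs=1, verbose=0)
```

Since training goes through fit(), callbacks can be passed in the usual way, and both losses are reported in history.history under the keys returned by train_step.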
In this article, I have covered three methods for implementing GAN model training in Keras. If you have suggestions for a better training method, or are struggling with some step, please do let me know.