import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Conv2D, Conv2DTranspose, LeakyReLU, BatchNormalization, Input, Concatenate
from tensorflow.keras.models import Model

def build_generator(input_shape=(256, 256, 3)):
    """Build the encoder-decoder generator used for each image-translation direction.

    Four stride-2 convolutions downsample the input by 16x, and four stride-2
    transposed convolutions upsample it back. The output therefore has the
    same height and width as the input, which must be divisible by 16 for
    the sizes to round-trip exactly.

    Args:
        input_shape: ``(height, width, channels)`` of the input images.
            The output has the same shape.

    Returns:
        A ``Model`` mapping an image batch to an image batch. The final
        ``tanh`` activation keeps values in ``[-1, 1]``.
    """
    inputs = Input(shape=input_shape)
    x = inputs
    # Encoder: each stride-2 conv halves height and width.
    for filters in (64, 128, 256, 512):
        x = Conv2D(filters, 4, strides=2, padding='same', activation='relu')(x)
        x = BatchNormalization()(x)
    # Decoder: each stride-2 transposed conv doubles height and width.
    for filters in (256, 128, 64):
        x = Conv2DTranspose(filters, 4, strides=2, padding='same', activation='relu')(x)
        x = BatchNormalization()(x)
    # The output layer performs the fourth upsampling step. It used to be a
    # stride-1 Conv2D, which left a 256x256 input at 128x128. That output
    # could not be scored by a 256x256 discriminator or compared against the
    # input in a cycle-consistency loss.
    outputs = Conv2DTranspose(
        input_shape[-1], 4, strides=2, padding='same', activation='tanh'
    )(x)
    return Model(inputs, outputs)

def build_discriminator(input_shape=(256, 256, 3)):
    """Build a fully convolutional discriminator that outputs a grid of scores.

    Four stride-2 convolutions downsample the input by 16x. A final stride-1
    convolution then produces one unbounded score per spatial location
    (PatchGAN-style). For the default 256x256 input, the output is a
    16x16x1 score map.

    Args:
        input_shape: ``(height, width, channels)`` of the images to score.

    Returns:
        A ``Model`` mapping an image batch to a batch of score maps. No
        output activation is applied.
    """
    inputs = Input(shape=input_shape)
    x = inputs
    for filters in (64, 128, 256, 512):
        # No activation on the conv itself. The previous activation='relu'
        # zeroed negative values, so the following LeakyReLU was a no-op.
        x = Conv2D(filters, 4, strides=2, padding='same')(x)
        x = LeakyReLU(alpha=0.2)(x)
    outputs = Conv2D(1, 4, strides=1, padding='same')(x)
    return Model(inputs, outputs)

# Training hyperparameters (module-level constants read by the training loop):
# how many epochs to run, and how many images make up one batch.
epochs, batch_size = 20, 1

# Load dataset (example)
def load_data(num_samples=10, image_shape=(256, 256, 3), seed=None):
    """Return placeholder image sets for domains X and Y.

    This is a stand-in for a real loader: it draws uniform random pixels.
    Values are scaled to ``[-1, 1)`` to match the ``tanh`` output range of
    the generators. Previously real images lay in ``[0, 1)``, so they never
    had the negative values that generated images can have.

    Args:
        num_samples: Number of images to generate per domain.
        image_shape: ``(height, width, channels)`` of each image.
        seed: Optional seed for reproducible placeholder data.

    Returns:
        A tuple ``(images_x, images_y)`` of float32 arrays, each shaped
        ``(num_samples, *image_shape)``.
    """
    rng = np.random.default_rng(seed)
    shape = (num_samples, *image_shape)
    # rng.random yields [0, 1); the affine map moves it to [-1, 1).
    images_x = rng.random(shape, dtype=np.float32) * 2.0 - 1.0
    images_y = rng.random(shape, dtype=np.float32) * 2.0 - 1.0
    return images_x, images_y

# Initialize models. generator_g translates domain X -> Y and generator_f
# translates Y -> X. discriminator_x scores domain-X images and
# discriminator_y scores domain-Y images.
generator_g = build_generator()
generator_f = build_generator()
discriminator_x = build_discriminator()
discriminator_y = build_discriminator()

# CycleGAN losses span several models at once: a generator is trained through
# a discriminator and through the other generator. Training therefore uses an
# explicit tf.GradientTape step rather than compile() + train_on_batch().
# The previous compile() calls passed no loss, so train_on_batch() could not
# run. They also regressed the generators onto all-ones "labels" instead of
# training them adversarially with cycle consistency.
generator_g_optimizer = tf.keras.optimizers.Adam(1e-4)
generator_f_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_x_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_y_optimizer = tf.keras.optimizers.Adam(1e-4)

# Conventional CycleGAN loss weights: cycle consistency weighted 10x the
# adversarial term, identity loss at half the cycle weight.
LAMBDA_CYCLE = 10.0
LAMBDA_IDENTITY = 0.5 * LAMBDA_CYCLE


def _lsgan_loss(scores, target):
    """Return the mean squared distance between discriminator scores and a constant target.

    The scalar target broadcasts over the score map, so no label tensor has
    to match the discriminator's output shape.
    """
    return tf.reduce_mean(tf.square(scores - target))


def _l1_loss(a, b):
    """Return the mean absolute difference between two image batches."""
    return tf.reduce_mean(tf.abs(a - b))


@tf.function
def _train_step(real_x, real_y):
    """Run one update of all four models on a batch from each domain.

    Returns:
        ``(d_loss_x, d_loss_y, g_loss_g, g_loss_f)`` as scalar tensors.
    """
    with tf.GradientTape(persistent=True) as tape:
        fake_y = generator_g(real_x, training=True)
        fake_x = generator_f(real_y, training=True)
        # Round trips X -> Y -> X and Y -> X -> Y for cycle consistency.
        cycled_x = generator_f(fake_y, training=True)
        cycled_y = generator_g(fake_x, training=True)
        # A generator fed an image already in its target domain should
        # return it nearly unchanged (identity loss).
        same_x = generator_f(real_x, training=True)
        same_y = generator_g(real_y, training=True)

        disc_real_x = discriminator_x(real_x, training=True)
        disc_fake_x = discriminator_x(fake_x, training=True)
        disc_real_y = discriminator_y(real_y, training=True)
        disc_fake_y = discriminator_y(fake_y, training=True)

        cycle_loss = _l1_loss(real_x, cycled_x) + _l1_loss(real_y, cycled_y)
        g_loss_g = (_lsgan_loss(disc_fake_y, 1.0)
                    + LAMBDA_CYCLE * cycle_loss
                    + LAMBDA_IDENTITY * _l1_loss(real_y, same_y))
        g_loss_f = (_lsgan_loss(disc_fake_x, 1.0)
                    + LAMBDA_CYCLE * cycle_loss
                    + LAMBDA_IDENTITY * _l1_loss(real_x, same_x))
        d_loss_x = 0.5 * (_lsgan_loss(disc_real_x, 1.0) + _lsgan_loss(disc_fake_x, 0.0))
        d_loss_y = 0.5 * (_lsgan_loss(disc_real_y, 1.0) + _lsgan_loss(disc_fake_y, 0.0))

    # Each loss updates only its own model's weights. All gradients are
    # computed before any update is applied, so all four come from the
    # same forward pass.
    updates = (
        (g_loss_g, generator_g, generator_g_optimizer),
        (g_loss_f, generator_f, generator_f_optimizer),
        (d_loss_x, discriminator_x, discriminator_x_optimizer),
        (d_loss_y, discriminator_y, discriminator_y_optimizer),
    )
    gradients = [tape.gradient(loss, model.trainable_variables) for loss, model, _ in updates]
    for (_, model, optimizer), grads in zip(updates, gradients):
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return d_loss_x, d_loss_y, g_loss_g, g_loss_f


# Load the data once. The two domains are shuffled independently because
# CycleGAN training data is unpaired.
real_images_x, real_images_y = load_data()
real_images_x = np.asarray(real_images_x, dtype=np.float32)
real_images_y = np.asarray(real_images_y, dtype=np.float32)
train_dataset = tf.data.Dataset.zip((
    tf.data.Dataset.from_tensor_slices(real_images_x).shuffle(len(real_images_x)),
    tf.data.Dataset.from_tensor_slices(real_images_y).shuffle(len(real_images_y)),
)).batch(batch_size)

# Training loop
for epoch in range(epochs):
    loss_sums = np.zeros(4)
    num_batches = 0
    for batch_x, batch_y in train_dataset:
        loss_sums += [float(loss) for loss in _train_step(batch_x, batch_y)]
        num_batches += 1
    d_loss_x, d_loss_y, g_loss_g, g_loss_f = loss_sums / max(num_batches, 1)

    print(f'Epoch [{epoch+1}/{epochs}], D Loss X: {d_loss_x:.4f}, D Loss Y: {d_loss_y:.4f}')
    print(f'G Loss G: {g_loss_g:.4f}, G Loss F: {g_loss_f:.4f}')

    if (epoch + 1) % 5 == 0:
        output_images = generator_g.predict(real_images_x[:batch_size], verbose=0)
        for i, image in enumerate(output_images):
            # tanh outputs lie in [-1, 1]; imsave expects floats in [0, 1].
            plt.imsave(f'cyclegan_image_{epoch+1}_{i}.png',
                       np.clip((image + 1.0) / 2.0, 0.0, 1.0))
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.
  • 8.
  • 9.
  • 10.
  • 11.
  • 12.
  • 13.
  • 14.
  • 15.
  • 16.
  • 17.
  • 18.
  • 19.
  • 20.
  • 21.
  • 22.
  • 23.
  • 24.
  • 25.
  • 26.
  • 27.
  • 28.
  • 29.
  • 30.
  • 31.
  • 32.
  • 33.
  • 34.
  • 35.
  • 36.
  • 37.
  • 38.
  • 39.
  • 40.
  • 41.
  • 42.
  • 43.
  • 44.
  • 45.
  • 46.
  • 47.
  • 48.
  • 49.
  • 50.
  • 51.
  • 52.
  • 53.
  • 54.
  • 55.
  • 56.
  • 57.
  • 58.
  • 59.
  • 60.
  • 61.
  • 62.
  • 63.
  • 64.
  • 65.
  • 66.
  • 67.
  • 68.
  • 69.
  • 70.
  • 71.
  • 72.
  • 73.
  • 74.
  • 75.
  • 76.
  • 77.
  • 78.
  • 79.
  • 80.
  • 81.
  • 82.
  • 83.
  • 84.
  • 85.
  • 86.