import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.applications import vgg19
from tensorflow.keras.preprocessing import image as kp_image
from tensorflow.keras import layers

# Load and preprocess image
def load_image(img_path, max_dim=512):
    """Load an image file as a batched float32 tensor suitable for VGG19.

    Args:
        img_path: Path to an image file readable by Keras ``load_img``
            (loaded with its default ``color_mode='rgb'``).
        max_dim: Upper bound, in pixels, on the image's longest side.
            Larger images are downscaled with their aspect ratio preserved;
            images that already fit are left untouched. ``None`` disables
            resizing entirely.

    Returns:
        A ``tf.float32`` tensor of shape ``(1, height, width, channels)``
        holding pixel values on the 0-255 scale produced by
        ``img_to_array`` (the scale ``vgg19.preprocess_input`` expects).
    """
    pil_img = kp_image.load_img(img_path)
    # float32 array, channels-last under the default Keras data format.
    # (The previous convert_image_dtype(float32 -> float32) call was a
    # no-op, so the values were always on the 0-255 scale.)
    pixels = kp_image.img_to_array(pil_img)
    if max_dim is not None and max(pixels.shape[:2]) > max_dim:
        # preserve_aspect_ratio fits the image inside max_dim x max_dim; it
        # would also upscale, which is why it only runs for oversized images.
        pixels = tf.image.resize(pixels, (max_dim, max_dim), preserve_aspect_ratio=True)
    img = tf.convert_to_tensor(pixels, dtype=tf.float32)
    # Add the leading batch axis expected by Keras models.
    return img[tf.newaxis, :]

# Define the Style Transfer model
class StyleTransferModel(tf.keras.Model):
    """Frozen VGG19 feature extractor for neural style transfer.

    Calling the model on a batch of RGB images with pixel values on the
    0-255 scale returns a ``(style_outputs, content_outputs)`` pair of dicts
    keyed by VGG19 layer name:

    * ``style_outputs[name]``: Gram matrix of that layer's activations,
      shape ``(batch, channels, channels)``;
    * ``content_outputs[name]``: that layer's raw activations.

    This pair is exactly what ``compute_loss`` unpacks. The VGG weights are
    frozen: style transfer optimises the input image, never the network.

    Args:
        style_layers: VGG19 layer names for the style term
            (default ``DEFAULT_STYLE_LAYERS``).
        content_layers: VGG19 layer names for the content term
            (default ``DEFAULT_CONTENT_LAYERS``).
    """

    # The conventional layer choice from Gatys et al. / the TensorFlow
    # style-transfer tutorial.
    DEFAULT_STYLE_LAYERS = (
        'block1_conv1',
        'block2_conv1',
        'block3_conv1',
        'block4_conv1',
        'block5_conv1',
    )
    DEFAULT_CONTENT_LAYERS = ('block5_conv2',)

    def __init__(self, style_layers=None, content_layers=None):
        super().__init__()
        style_layers = tuple(self.DEFAULT_STYLE_LAYERS if style_layers is None else style_layers)
        content_layers = tuple(
            self.DEFAULT_CONTENT_LAYERS if content_layers is None else content_layers
        )
        self.style_layers = style_layers
        self.content_layers = content_layers

        self.vgg = vgg19.VGG19(weights='imagenet', include_top=False)
        self.vgg.trainable = False
        # One functional model exposing every needed intermediate activation,
        # so a single forward pass yields both style and content features.
        layer_outputs = [self.vgg.get_layer(name).output for name in style_layers + content_layers]
        self.extractor = tf.keras.Model(self.vgg.input, layer_outputs)
        self.extractor.trainable = False

    @staticmethod
    def gram_matrix(features):
        """Return per-image channel correlations of a channels-last feature map.

        ``features`` has shape ``(batch, height, width, channels)``; the result
        has shape ``(batch, channels, channels)`` and is divided by the number
        of spatial locations so images of different sizes are comparable.
        """
        gram = tf.linalg.einsum('bijc,bijd->bcd', features, features)
        shape = tf.shape(features)
        num_locations = tf.cast(shape[1] * shape[2], gram.dtype)
        return gram / num_locations

    def call(self, inputs):
        # VGG19 was trained on mean-centred BGR images on the 0-255 scale;
        # preprocess_input performs that conversion from 0-255 RGB.
        features = self.extractor(vgg19.preprocess_input(inputs))
        if not isinstance(features, (list, tuple)):
            # A single-output extractor returns a bare tensor.
            features = [features]
        n_style = len(self.style_layers)
        style_outputs = {
            name: self.gram_matrix(feat)
            for name, feat in zip(self.style_layers, features[:n_style])
        }
        content_outputs = {
            name: feat for name, feat in zip(self.content_layers, features[n_style:])
        }
        return style_outputs, content_outputs


def compute_loss(style_weight, content_weight, outputs, style_targets, content_targets):
    """Return the weighted style + content loss for one forward pass.

    Args:
        style_weight: Multiplier for the style term.
        content_weight: Multiplier for the content term.
        outputs: ``(style_outputs, content_outputs)`` pair, as returned by
            ``StyleTransferModel``, for the image being optimised.
        style_targets: Dict with the same keys as ``style_outputs``, computed
            from the style image.
        content_targets: Dict with the same keys as ``content_outputs``,
            computed from the content image.

    Returns:
        Scalar tensor: for each group, the per-layer mean squared errors are
        summed, divided by the number of layers, and scaled by the group's
        weight; the two groups are then added.
    """
    style_outputs, content_outputs = outputs
    style_loss = tf.add_n([
        tf.reduce_mean((style_outputs[name] - style_targets[name]) ** 2)
        for name in style_outputs
    ])
    content_loss = tf.add_n([
        tf.reduce_mean((content_outputs[name] - content_targets[name]) ** 2)
        for name in content_outputs
    ])
    style_loss *= style_weight / len(style_outputs)
    content_loss *= content_weight / len(content_outputs)
    return style_loss + content_loss


def run_style_transfer(content_image, style_image, model, optimizer, *,
                       style_weight, content_weight, epochs,
                       steps_per_epoch=100, save_every=5,
                       output_pattern='style_transfer_image_{}.png'):
    """Optimise an image to keep content_image's content in style_image's style.

    Args:
        content_image: Batched RGB tensor on the 0-255 scale (as produced
            by ``load_image``); also the starting point of the optimisation.
        style_image: Batched RGB tensor on the 0-255 scale.
        model: A ``StyleTransferModel`` (or anything returning the same
            ``(style_outputs, content_outputs)`` pair).
        optimizer: Keras optimizer applied to the image pixels.
        style_weight, content_weight: Loss weights passed to ``compute_loss``.
        epochs: Number of reporting periods.
        steps_per_epoch: Gradient steps per epoch (must be >= 1).
        save_every: Write a snapshot every this many epochs (must be >= 1).
        output_pattern: ``str.format`` pattern for snapshot file names,
            filled with the 1-based epoch number.

    Returns:
        The final generated image, a float32 tensor shaped like
        ``content_image`` with values in [0, 1].

    Raises:
        ValueError: if ``steps_per_epoch`` or ``save_every`` is below 1.
    """
    if steps_per_epoch < 1 or save_every < 1:
        raise ValueError('steps_per_epoch and save_every must be >= 1')

    # The targets never change, so compute them once up front.
    style_targets, _ = model(style_image)
    _, content_targets = model(content_image)

    # The trainable quantity is the image itself. Optimise it on the [0, 1]
    # scale so a learning rate like 0.02 moves pixels by a small fraction of
    # their range, and rescale to 0-255 before feeding VGG19.
    generated = tf.Variable(content_image / 255.0)

    @tf.function
    def train_step():
        with tf.GradientTape() as tape:
            outputs = model(generated * 255.0)
            loss = compute_loss(style_weight, content_weight, outputs,
                                style_targets, content_targets)
        grad = tape.gradient(loss, generated)
        optimizer.apply_gradients([(grad, generated)])
        # Keep pixels in the valid, displayable range.
        generated.assign(tf.clip_by_value(generated, 0.0, 1.0))
        return loss

    for epoch in range(epochs):
        for _ in range(steps_per_epoch):
            loss = train_step()
        print(f'Epoch {epoch + 1}/{epochs}: loss = {float(loss):.4g}')
        if (epoch + 1) % save_every == 0:
            # Drop the batch axis; values are already clipped to [0, 1].
            frame = np.squeeze(generated.numpy(), axis=0)
            plt.imsave(output_pattern.format(epoch + 1), frame)
    return tf.convert_to_tensor(generated)


# Hyperparameters
style_weight = 1e-2
# NOTE(review): was 1e-4. With this loss scaling (size-normalised Gram
# matrices, raw content activations) the TensorFlow style-transfer tutorial
# balances style_weight=1e-2 against content_weight=1e4; 1e-4 looks like a
# dropped sign and would likely make the content term negligible. Tune to taste.
content_weight = 1e4
epochs = 10
steps_per_epoch = 100  # gradient steps (image updates) per epoch
learning_rate = 0.02

# Load images (batched, RGB, 0-255 scale)
style_image = load_image('style_image.jpg')
content_image = load_image('content_image.jpg')

# Build the frozen feature extractor and the optimiser for the image pixels.
# There is nothing to compile or fit: the network weights never change.
model = StyleTransferModel()
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

# Optimisation loop; writes style_transfer_image_<epoch>.png every 5 epochs.
output_image = run_style_transfer(
    content_image, style_image, model, optimizer,
    style_weight=style_weight, content_weight=content_weight,
    epochs=epochs, steps_per_epoch=steps_per_epoch, save_every=5,
)
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.
  • 8.
  • 9.
  • 10.
  • 11.
  • 12.
  • 13.
  • 14.
  • 15.
  • 16.
  • 17.
  • 18.
  • 19.
  • 20.
  • 21.
  • 22.
  • 23.
  • 24.
  • 25.
  • 26.
  • 27.
  • 28.
  • 29.
  • 30.
  • 31.
  • 32.
  • 33.
  • 34.
  • 35.
  • 36.
  • 37.
  • 38.
  • 39.
  • 40.
  • 41.
  • 42.
  • 43.
  • 44.
  • 45.
  • 46.
  • 47.
  • 48.
  • 49.
  • 50.
  • 51.
  • 52.
  • 53.
  • 54.
  • 55.
  • 56.
  • 57.