import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models, losses, optimizers
# Function to create masked images and their corresponding originals
def create_masked_images(images, mask_fraction=0.2, rng=None):
    """Randomly blank out pixels to build (masked input, original target) pairs.

    Each element of ``images`` is independently set to 0 with probability
    ``mask_fraction`` and kept unchanged otherwise.

    Args:
        images: Array-like of pixel values of any shape, e.g. a stack of
            images normalized to ``[0, 1]``.
        mask_fraction: Probability that an individual pixel value is masked
            (zeroed). Must lie in ``[0, 1]``: ``0`` keeps every pixel and
            ``1`` masks every pixel.
        rng: Optional ``numpy.random.Generator`` for reproducible masks. When
            ``None``, the global NumPy random state is used, so
            ``np.random.seed`` still controls the result.

    Returns:
        Tuple ``(masked_images, images)``: the masked copy, and the input array
        itself (not copied) to be used as the reconstruction target.

    Raises:
        ValueError: If ``mask_fraction`` is outside ``[0, 1]``.
    """
    if not 0.0 <= mask_fraction <= 1.0:
        raise ValueError(f"mask_fraction must be in [0, 1], got {mask_fraction!r}")
    images = np.asarray(images)
    draw = np.random.random if rng is None else rng.random
    # Uniform draws are in [0, 1): keeping a pixel when its draw is >= mask_fraction
    # zeroes it with probability mask_fraction (not 1 - mask_fraction).
    keep = draw(images.shape) >= mask_fraction
    masked_images = images * keep
    return masked_images, images
# Create a simple image inpainting model
def create_inpainting_model(input_shape, out_channels=1):
    """Build a small fully-convolutional image-inpainting network.

    Four 3x3 convolutions with ``'same'`` padding keep the spatial size
    unchanged. The model therefore maps a masked image to a restored image with
    the same height and width.

    Args:
        input_shape: Per-sample input shape without the batch axis, e.g.
            ``(height, width, channels)``.
        out_channels: Number of channels in the restored image. Defaults to 1
            (grayscale). For colour inputs, pass ``input_shape[-1]`` so the
            output matches the input channels.

    Returns:
        An uncompiled Keras ``Sequential`` model. Its sigmoid output lies in
        ``(0, 1)``, matching images normalized to ``[0, 1]``.
    """
    model = models.Sequential()
    # An explicit Input layer declares the shape; passing ``input_shape=`` to the
    # first Conv2D is deprecated in Keras 3.
    model.add(layers.Input(shape=input_shape))
    model.add(layers.Conv2D(32, (3, 3), activation='relu', padding='same'))
    model.add(layers.Conv2D(64, (3, 3), activation='relu', padding='same'))
    model.add(layers.Conv2D(32, (3, 3), activation='relu', padding='same'))
    model.add(layers.Conv2D(out_channels, (3, 3), activation='sigmoid', padding='same'))
    return model
# Load a sample dataset. MNIST is used here. To use your own data, replace this
# call with anything that yields image arrays shaped (N, H, W) or (N, H, W, C).
(X_train, _), (X_test, _) = tf.keras.datasets.mnist.load_data()

# Normalize pixel values from [0, 255] to [0, 1], then create masked copies.
# The unmasked images are the reconstruction targets.
X_train = X_train.astype('float32') / 255.0
X_test = X_test.astype('float32') / 255.0
X_masked_train, Y_train = create_masked_images(X_train)
X_masked_test, Y_test = create_masked_images(X_test)

# Grayscale datasets such as MNIST are (N, H, W), but Conv2D expects an explicit
# channel axis, so add one. Arrays that already have a channel axis are left as-is.
if X_masked_train.ndim == 3:
    X_masked_train = np.expand_dims(X_masked_train, axis=-1)
    X_masked_test = np.expand_dims(X_masked_test, axis=-1)
    Y_train = np.expand_dims(Y_train, axis=-1)
    Y_test = np.expand_dims(Y_test, axis=-1)

# Build the model for the per-sample shape, e.g. (28, 28, 1) for MNIST.
input_shape = X_masked_train[0].shape
model = create_inpainting_model(input_shape)

# Pixel-wise mean squared error between the restored and original images.
model.compile(optimizer='adam', loss='mse')

# Train: masked inputs -> original targets.
# NOTE: the test split doubles as validation data here, so validation loss is not
# an independent held-out score.
model.fit(X_masked_train, Y_train, epochs=10, batch_size=32,
          validation_data=(X_masked_test, Y_test))
# Inference: run the trained model on one masked image.
def restore_image(model, masked_image):
    """Return ``model``'s reconstruction of a single masked image.

    ``model.predict`` operates on batches, so a leading batch axis of length 1
    is added before the call and stripped from the prediction afterwards.

    Args:
        model: Object with a Keras-style ``predict(batch)`` method.
        masked_image: One image, without a batch axis.

    Returns:
        The prediction for that image, without the batch axis.
    """
    single_item_batch = np.expand_dims(masked_image, 0)
    batch_prediction = model.predict(single_item_batch)
    # Remove the batch axis again; np.squeeze raises if its length is not 1.
    return np.squeeze(batch_prediction, axis=0)
# Visual check: show a masked test image, the model's reconstruction and the
# original image side by side.
import matplotlib.pyplot as plt

sample_index = 0
masked_image = X_masked_test[sample_index]
restored_image = restore_image(model, masked_image)
original_image = Y_test[sample_index]

panels = [
    ("Masked Image", masked_image),
    ("Restored Image", restored_image),
    ("Original Image", original_image),
]
plt.figure(figsize=(9, 3))
for position, (title, image) in enumerate(panels, start=1):
    plt.subplot(1, len(panels), position)
    plt.title(title)
    # Pin the intensity range to the normalized pixel range [0, 1]. Otherwise
    # imshow rescales each panel to its own min/max and the panels are not
    # visually comparable.
    plt.imshow(image.squeeze(), cmap='gray', vmin=0.0, vmax=1.0)
    plt.axis('off')
plt.show()
# [mask_demo2]
# Web-page residue from the original post: "Latest recommended article published on 2024-05-19 23:31:32".