# VGG-19
import keras
import math
import numpy as np
from keras.datasets import cifar10
from keras.preprocessing.image import ImageDataGenerator
from keras.layers.normalization import BatchNormalization
from keras.layers import Conv2D, Dense, Input, add, Activation, AveragePooling2D, GlobalAveragePooling2D, MaxPooling2D, Flatten, Dropout
from keras.layers import Lambda, concatenate
from keras.initializers import he_normal
from keras.layers.merge import Concatenate
from keras.callbacks import LearningRateScheduler, TensorBoard, ModelCheckpoint
from keras.models import Model
from keras import optimizers
from keras import regularizers
from keras.utils import plot_model
import time
# NOTE(review): growth_rate / depth / compression are not used by the VGG-19 code in this
# section (they look like DenseNet leftovers); kept in case other code reads them.
growth_rate = 12
depth = 100
compression = 0.5
img_rows, img_cols = 32, 32
# CIFAR-10 images are RGB and __main__ standardizes 3 channels; a value of 1 here made the
# model's Input shape disagree with the training data.
img_channels = 3
num_classes = 10
batch_size = 128
epochs = 200
iterations = 391          # steps per epoch: ceil(50000 / 128)
dropout = 0.5
weight_decay = 0.0001
# NOTE(review): not used by the TensorBoard callback below, which logs to './vgg19_newlog_def/'.
log_filepath = r'./vgg19_retrain_logs/'
def scheduler(epoch):
    """Return the step-decayed SGD learning rate for ``epoch``.

    0.1 for epochs [0, 40), 0.01 for [40, 100), 0.001 for [100, 150),
    and 0.0001 from epoch 150 on.
    """
    for boundary, lr in ((40, 0.1), (100, 0.01), (150, 0.001)):
        if epoch < boundary:
            return lr
    return 0.0001
import os
import tensorflow as tf
from keras.backend.tensorflow_backend import set_session

# Make only GPU 2 visible to this process, and install a TF1 session that may
# allocate at most 40% of that GPU's memory as Keras' default session.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.4
set_session(tf.Session(config=config))
# Per-channel statistics (index i = color channel i) used in __main__ to
# standardize the images as (x - mean[i]) / std[i].
mean, std = (
    [125.307, 122.95, 113.865],
    [62.9932, 62.0887, 66.7048],
)
## --------------------------------------------------------------------------
def _conv_bn_relu(x, filters):
    """Apply one VGG conv unit: 3x3 'same' Conv2D (no bias) -> BatchNorm -> ReLU.

    The convolution is He-normal initialised and L2-regularised with the
    module-level ``weight_decay``.
    """
    x = Conv2D(filters, kernel_size=(3, 3), strides=(1, 1), padding='same', kernel_initializer=he_normal(),
               kernel_regularizer=regularizers.l2(weight_decay), use_bias=False)(x)
    x = BatchNormalization()(x)
    return Activation('relu')(x)


def _dense_bn(x, units):
    """Apply a He-normal, L2-regularised Dense layer (with bias) followed by BatchNorm."""
    x = Dense(units, use_bias=True, kernel_regularizer=regularizers.l2(weight_decay),
              kernel_initializer=he_normal())(x)
    return BatchNormalization()(x)


# (filters, number of 3x3 conv units) for each of VGG-19's five blocks; every
# block ends with a 2x2 / stride-2 max-pool.
_VGG19_BLOCKS = ((64, 2), (128, 2), (256, 4), (512, 4), (512, 4))


def vgg19(img_input, classes_num):
    """Build a batch-normalised VGG-19 classifier on top of ``img_input``.

    Args:
        img_input: Keras input tensor. In this script it is 32x32, so the five
            2x2 poolings reduce it to 1x1 before the classifier head.
        classes_num: number of output classes (width of the softmax layer).

    Returns:
        Output tensor of per-class softmax probabilities.
    """
    x = img_input
    # Layers are created in exactly the same order as the unrolled original, so
    # auto-generated layer names and saved-weight ordering are unchanged.
    for filters, n_convs in _VGG19_BLOCKS:
        for _ in range(n_convs):
            x = _conv_bn_relu(x, filters)
        x = MaxPooling2D((2, 2), strides=(2, 2))(x)

    # Classifier head: two 4096-unit FC layers with BN, ReLU and dropout.
    x = Flatten()(x)
    for _ in range(2):
        x = _dense_bn(x, 4096)
        x = Activation('relu')(x)
        x = Dropout(dropout)(x)

    # Output layer: sized by classes_num (was hard-coded to 10, ignoring the argument).
    x = _dense_bn(x, classes_num)
    return Activation('softmax')(x)
if __name__ == '__main__':
    # load data
    (x_train, y_train), (x_test, y_test) = cifar10.load_data()
    y_train = keras.utils.to_categorical(y_train, num_classes)
    y_test = keras.utils.to_categorical(y_test, num_classes)
    x_train = x_train.astype('float32')
    x_test = x_test.astype('float32')

    # Per-channel standardization: (x - mean) / std for each of the 3 channels.
    for i in range(3):
        x_train[:, :, :, i] = (x_train[:, :, :, i] - mean[i]) / std[i]
        x_test[:, :, :, i] = (x_test[:, :, :, i] - mean[i]) / std[i]

    # time start
    time_start = time.time()

    # build network
    img_input = Input(shape=(img_rows, img_cols, img_channels))
    output = vgg19(img_input, num_classes)
    model = Model(img_input, output)
    # model.load_weights('ckpt.h5')
    # plot_model(model, show_shapes=True, to_file='model.png')
    # summary() prints the table itself and returns None; wrapping it in print()
    # also printed a stray "None".
    model.summary()

    # set optimizer
    sgd = optimizers.SGD(lr=.1, momentum=0.9, nesterov=True)
    model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])

    # set callback: LR schedule, TensorBoard logs, and a checkpoint every 10 epochs
    tb_cb = TensorBoard(log_dir='./vgg19_newlog_def/', histogram_freq=0)
    change_lr = LearningRateScheduler(scheduler)
    ckpt = ModelCheckpoint('./ckpt.h5', save_best_only=False, mode='auto', period=10)
    cbks = [change_lr, tb_cb, ckpt]

    # set data augmentation: random horizontal flips and shifts of up to 1/8 of the
    # image size, padding with zeros
    print('Using real-time data augmentation.')
    datagen = ImageDataGenerator(horizontal_flip=True, width_shift_range=0.125, height_shift_range=0.125,
                                 fill_mode='constant', cval=0.)
    datagen.fit(x_train)

    # start training
    model.fit_generator(datagen.flow(x_train, y_train, batch_size=batch_size), steps_per_epoch=iterations,
                        epochs=epochs, callbacks=cbks, validation_data=(x_test, y_test))

    # time end
    time_end = time.time()
    print("The time has passed %d s" % (time_end - time_start))
    model.save('vgg19.h5')
# resnet
import keras
import numpy as np
from keras.datasets import cifar10
from keras.preprocessing.image import ImageDataGenerator
from keras.layers.normalization import BatchNormalization
from keras.layers import Conv2D, Dense, Input, add, Activation, GlobalAveragePooling2D
from keras.callbacks import LearningRateScheduler, TensorBoard, ModelCheckpoint
from keras.models import Model
from keras import optimizers, regularizers
import os
# Make only GPU 3 visible to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

# Network shape: stack_n residual blocks per stage, 32x32x3 inputs, 10 classes.
stack_n = 5
num_classes = 10
img_rows = img_cols = 32
img_channels = 3

# Optimisation: 200 epochs of 128-image batches; steps per epoch is the floor of
# 50000 / batch_size (presumably the CIFAR-10 training-set size).
batch_size = 128
epochs = 200
iterations = 50000 // batch_size
weight_decay = 0.0001

# Per-channel statistics used by color_preprocessing().
mean = [125.307, 122.95, 113.865]
std = [62.9932, 62.0887, 66.7048]
def scheduler(epoch):
    """Return the SGD learning rate for ``epoch``.

    The rate is 0.1 before epoch 80, 0.01 before epoch 150, and 0.001 afterwards.
    """
    return 0.1 if epoch < 80 else (0.01 if epoch < 150 else 0.001)
def residual_network(img_input, classes_num=10, stack_n=5):
    """Build a pre-activation residual network for 32x32 inputs.

    Layout: a 16-filter 3x3 stem conv, then three stages of ``stack_n`` residual
    blocks with 16, 32 and 64 filters. The first block of the 32- and 64-filter
    stages halves the spatial size (32x32 -> 16x16 -> 8x8) through a stride-2 conv
    and a 1x1 stride-2 projection shortcut. A BN-ReLU, global average pooling and a
    softmax Dense layer finish the network.
    Total weighted layers = stack_n * 3 * 2 + 2 (32 for the default stack_n=5).

    Args:
        img_input: Keras input tensor.
        classes_num: number of output classes.
        stack_n: residual blocks per stage.

    Returns:
        Output tensor of per-class softmax probabilities.
    """
    def residual_block(block_input, out_channel, increase=False):
        # Pre-activation block: BN -> ReLU -> conv -> BN -> ReLU -> conv, plus a shortcut.
        stride = (2, 2) if increase else (1, 1)

        h = BatchNormalization()(block_input)
        h = Activation('relu')(h)
        h = Conv2D(out_channel, kernel_size=(3, 3), strides=stride, padding='same',
                   kernel_initializer="he_normal",
                   kernel_regularizer=regularizers.l2(weight_decay))(h)
        h = BatchNormalization()(h)
        h = Activation('relu')(h)
        h = Conv2D(out_channel, kernel_size=(3, 3), strides=(1, 1), padding='same',
                   kernel_initializer="he_normal",
                   kernel_regularizer=regularizers.l2(weight_decay))(h)

        if not increase:
            # Identity shortcut: shapes already match.
            return add([block_input, h])

        # Downsampling block: 1x1 stride-2 projection so the shortcut matches h's shape.
        shortcut = Conv2D(out_channel,
                          kernel_size=(1, 1),
                          strides=(2, 2),
                          padding='same',
                          kernel_initializer="he_normal",
                          kernel_regularizer=regularizers.l2(weight_decay))(block_input)
        return add([h, shortcut])

    # Stem: 32x32x3 -> 32x32x16
    x = Conv2D(filters=16, kernel_size=(3, 3), strides=(1, 1), padding='same',
               kernel_initializer="he_normal",
               kernel_regularizer=regularizers.l2(weight_decay))(img_input)

    # Stage 1 keeps 32x32x16.
    for _ in range(stack_n):
        x = residual_block(x, 16, False)

    # Stages 2 and 3: 32x32x16 -> 16x16x32 -> 8x8x64; the first block of each downsamples.
    for filters in (32, 64):
        x = residual_block(x, filters, True)
        for _ in range(stack_n - 1):
            x = residual_block(x, filters, False)

    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = GlobalAveragePooling2D()(x)

    # 64 -> classes_num probabilities
    return Dense(classes_num, activation='softmax',
                 kernel_initializer="he_normal",
                 kernel_regularizer=regularizers.l2(weight_decay))(x)
def color_preprocessing(x_train, x_test, channel_mean=None, channel_std=None):
    """Standardize train and test images per color channel, in float32.

    Each channel c becomes ``(x[..., c] - channel_mean[c]) / channel_std[c]``.

    Args:
        x_train: training images, channels on the last axis.
        x_test: test images, channels on the last axis.
        channel_mean: per-channel means. Defaults to the module-level ``mean``.
        channel_std: per-channel standard deviations. Defaults to the
            module-level ``std``.

    Returns:
        ``(x_train, x_test)`` as new float32 arrays; the inputs are not modified.
        The number of channels must equal ``len(channel_mean)``.
    """
    if channel_mean is None:
        channel_mean = mean
    if channel_std is None:
        channel_std = std
    # float32 statistics keep the broadcast arithmetic (and result dtype) in
    # float32 instead of promoting to float64.
    m = np.asarray(channel_mean, dtype='float32')
    s = np.asarray(channel_std, dtype='float32')
    x_train = (x_train.astype('float32') - m) / s
    x_test = (x_test.astype('float32') - m) / s
    return x_train, x_test
if __name__ == '__main__':
    # Labels come from CIFAR-10; the image arrays are replaced just below.
    (x_train, y_train), (x_test, y_test) = cifar10.load_data()
    y_train = keras.utils.to_categorical(y_train, num_classes)
    y_test = keras.utils.to_categorical(y_test, num_classes)
    # NOTE(review): assumes these .npy arrays hold the same samples in the same order
    # as cifar10.load_data() (so the labels above still line up), shaped (N, 32, 32, 3) — confirm.
    x_train = np.load('../bayer_sysm_x_train.npy')
    x_test = np.load('../bayer_sysm_x_test.npy')

    # color preprocessing
    # x_train, x_test = color_preprocessing(x_train, x_test)

    # build network
    img_input = Input(shape=(img_rows, img_cols, img_channels))
    output = residual_network(img_input, num_classes, stack_n)
    resnet = Model(img_input, output)
    # summary() prints the table itself and returns None; wrapping it in print()
    # also printed a stray "None".
    resnet.summary()

    # set optimizer
    sgd = optimizers.SGD(lr=.1, momentum=0.9, nesterov=True)
    resnet.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])

    # set callback: TensorBoard logs, LR schedule, and a checkpoint every 10 epochs
    cbks = [TensorBoard(log_dir='./resnet_32/', histogram_freq=0),
            LearningRateScheduler(scheduler),
            ModelCheckpoint('./checkpoint-{epoch}.h5', save_best_only=False, mode='auto', period=10)]

    # set data augmentation
    print('Using real-time data augmentation.')
    datagen = ImageDataGenerator(horizontal_flip=True,
                                 width_shift_range=0.125,
                                 height_shift_range=0.125,
                                 fill_mode='constant', cval=0.)
    datagen.fit(x_train)

    # start training
    resnet.fit_generator(datagen.flow(x_train, y_train, batch_size=batch_size),
                         steps_per_epoch=iterations,
                         epochs=epochs,
                         callbacks=cbks,
                         validation_data=(x_test, y_test))
    resnet.save('resnet.h5')
# gan
from __future__ import print_function, division
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
import matplotlib.pyplot as plt
import sys
import numpy as np
class GAN():
    """Fully-connected GAN that learns to generate 28x28x1 MNIST digits.

    The generator maps a ``latent_dim``-dimensional Gaussian noise vector to an
    image in [-1, 1] (tanh output). The discriminator maps an image to a single
    sigmoid "real" probability. ``combined`` stacks them so the generator can be
    trained through a frozen discriminator.
    """

    def __init__(self):
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 100

        optimizer = Adam(0.0002, 0.5)

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
                                   optimizer=optimizer,
                                   metrics=['accuracy'])

        # Build the generator
        self.generator = self.build_generator()

        # The generator takes noise as input and generates imgs
        z = Input(shape=(self.latent_dim,))
        img = self.generator(z)

        # For the combined model we will only train the generator. The flag is set
        # after the standalone discriminator was compiled above, which is the usual
        # Keras pattern for freezing D only inside the stacked model.
        self.discriminator.trainable = False

        # The discriminator takes generated images as input and determines validity
        validity = self.discriminator(img)

        # The combined model (stacked generator and discriminator)
        # trains the generator to fool the discriminator
        self.combined = Model(z, validity)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)

    def build_generator(self):
        """Return a model mapping (latent_dim,) noise to an image of ``img_shape`` in [-1, 1]."""
        model = Sequential()

        model.add(Dense(256, input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(np.prod(self.img_shape), activation='tanh'))
        model.add(Reshape(self.img_shape))

        model.summary()

        noise = Input(shape=(self.latent_dim,))
        img = model(noise)

        return Model(noise, img)

    def build_discriminator(self):
        """Return a model mapping an image of ``img_shape`` to a (1,) sigmoid validity score."""
        model = Sequential()

        model.add(Flatten(input_shape=self.img_shape))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(1, activation='sigmoid'))
        model.summary()

        img = Input(shape=self.img_shape)
        validity = model(img)

        return Model(img, validity)

    def train(self, epochs, batch_size=128, sample_interval=50):
        """Alternate one discriminator update and one generator update per iteration.

        Args:
            epochs: number of iterations; each draws a single random batch, not a full pass.
            batch_size: images per batch.
            sample_interval: save a sample grid every this many iterations.
        """
        # Load the dataset
        (X_train, _), (_, _) = mnist.load_data()

        # Rescale -1 to 1 to match the generator's tanh output range
        X_train = X_train / 127.5 - 1.
        X_train = np.expand_dims(X_train, axis=3)

        # Adversarial ground truths
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Select a random batch of images
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # Generate a batch of new images
            gen_imgs = self.generator.predict(noise)

            # Train the discriminator
            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator
            # ---------------------

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # Train the generator (to have the discriminator label samples as valid)
            g_loss = self.combined.train_on_batch(noise, valid)

            # Plot the progress
            print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

            # If at save interval => save generated image samples
            if epoch % sample_interval == 0:
                self.sample_images(epoch)

    def sample_images(self, epoch):
        """Save a 5x5 grid of generated digits to ``images/<epoch>.png``."""
        import os  # local import: this script section does not import os at file level

        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise)

        # Rescale images 0 - 1
        gen_imgs = 0.5 * gen_imgs + 0.5

        # savefig does not create missing directories; previously this crashed on a
        # fresh checkout without an images/ folder.
        os.makedirs("images", exist_ok=True)

        fig, axs = plt.subplots(r, c)
        for i in range(r):
            for j in range(c):
                axs[i, j].imshow(gen_imgs[i * c + j, :, :, 0], cmap='gray')
                axs[i, j].axis('off')
        fig.savefig("images/%d.png" % epoch)
        # Close this specific figure so repeated sampling does not leak figures.
        plt.close(fig)
if __name__ == '__main__':
    # 30000 single-batch iterations of 32 images; a sample grid every 200 iterations.
    gan = GAN()
    gan.train(batch_size=32, sample_interval=200, epochs=30000)
# SRgan
"""
Super-resolution of CelebA using Generative Adversarial Networks.
The dataset can be downloaded from: https://www.dropbox.com/sh/8oqt9vytwxb3s4r/AADIKlz8PR9zr6Y20qbkunrba/Img/img_align_celeba.zip?dl=0
Instructions for running the script:
1. Download the dataset from the provided link.
2. Save the folder 'img_align_celeba' to 'datasets/'.
3. Run the script using the command 'python srgan.py'.
"""
from __future__ import print_function, division
import scipy
from keras.datasets import mnist
from keras_contrib.layers.normalization import InstanceNormalization
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate
from keras.layers import BatchNormalization, Activation, ZeroPadding2D, Add
from keras.layers.advanced_activations import PReLU, LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.applications import VGG19
from keras.models import Sequential, Model
from keras.optimizers import Adam
import datetime
import matplotlib.pyplot as plt
import sys
from data_loader import DataLoader
import numpy as np
import os
import keras.backend as K
class SRGAN():
    """4x super-resolution GAN trained on the 'img_align_celeba' dataset.

    The generator maps a 64x64x3 low-resolution image to a 256x256x3 image in
    [-1, 1]. The discriminator scores 16x16 patches of a high-resolution image.
    The combined model trains the generator on two losses:
    - an adversarial term (weight 1e-3);
    - an MSE between pre-trained VGG19 features of generated and real HR images (weight 1).
    """

    def __init__(self):
        # Input shape
        self.channels = 3
        self.lr_height = 64                 # Low resolution height
        self.lr_width = 64                  # Low resolution width
        self.lr_shape = (self.lr_height, self.lr_width, self.channels)
        self.hr_height = self.lr_height*4   # High resolution height
        self.hr_width = self.lr_width*4     # High resolution width
        self.hr_shape = (self.hr_height, self.hr_width, self.channels)

        # Number of residual blocks in the generator
        self.n_residual_blocks = 16

        optimizer = Adam(0.0002, 0.5)

        # We use a pre-trained VGG19 model to extract image features from the high resolution
        # and the generated high resolution images and minimize the mse between them
        self.vgg = self.build_vgg()
        self.vgg.trainable = False
        self.vgg.compile(loss='mse',
                         optimizer=optimizer,
                         metrics=['accuracy'])

        # Configure data loader
        self.dataset_name = 'img_align_celeba'
        self.data_loader = DataLoader(dataset_name=self.dataset_name,
                                      img_res=(self.hr_height, self.hr_width))

        # Calculate output shape of D (PatchGAN): the discriminator has four stride-2
        # blocks, so its output is the HR size divided by 2**4.
        patch = int(self.hr_height / 2**4)
        self.disc_patch = (patch, patch, 1)

        # Number of filters in the first layer of G and D
        self.gf = 64
        self.df = 64

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='mse',
                                   optimizer=optimizer,
                                   metrics=['accuracy'])

        # Build the generator
        self.generator = self.build_generator()

        # High res. and low res. images. NOTE: img_hr is an input of the combined model,
        # but no output depends on it; train() still feeds it.
        img_hr = Input(shape=self.hr_shape)
        img_lr = Input(shape=self.lr_shape)

        # Generate high res. version from low res.
        fake_hr = self.generator(img_lr)

        # Extract image features of the generated img
        fake_features = self.vgg(fake_hr)

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # Discriminator determines validity of generated high res. images
        validity = self.discriminator(fake_hr)

        self.combined = Model([img_lr, img_hr], [validity, fake_features])
        self.combined.compile(loss=['binary_crossentropy', 'mse'],
                              loss_weights=[1e-3, 1],
                              optimizer=optimizer)

    def build_vgg(self):
        """
        Build a pre-trained VGG19 model that outputs the 'block3_conv3' feature map
        of an image of shape ``hr_shape``.
        """
        # include_top=False drops the classifier head. The full ImageNet model is fixed
        # at 224x224, and its Flatten/Dense layers cannot run on hr_shape inputs.
        # The convolutional layers (and their weights) are identical in both variants.
        vgg = VGG19(weights="imagenet", include_top=False, input_shape=self.hr_shape)
        # layers[0] is the input layer, so layers[9] is 'block3_conv3' (the third of the
        # four block-3 convolutions).
        # See architecture at: https://github.com/keras-team/keras/blob/master/keras/applications/vgg19.py
        return Model(vgg.input, vgg.layers[9].output)

    def build_generator(self):
        """Return a model mapping an ``lr_shape`` image to an ``hr_shape`` image in [-1, 1]."""

        def residual_block(layer_input, filters):
            """Residual block described in paper"""
            d = Conv2D(filters, kernel_size=3, strides=1, padding='same')(layer_input)
            d = Activation('relu')(d)
            d = BatchNormalization(momentum=0.8)(d)
            d = Conv2D(filters, kernel_size=3, strides=1, padding='same')(d)
            d = BatchNormalization(momentum=0.8)(d)
            d = Add()([d, layer_input])
            return d

        def deconv2d(layer_input):
            """Layers used during upsampling: 2x nearest upsample followed by a 3x3 conv."""
            u = UpSampling2D(size=2)(layer_input)
            u = Conv2D(256, kernel_size=3, strides=1, padding='same')(u)
            u = Activation('relu')(u)
            return u

        # Low resolution image input
        img_lr = Input(shape=self.lr_shape)

        # Pre-residual block. Its width must equal the residual blocks' width (self.gf)
        # for the Add() skip connections; it was hard-coded to 64 before.
        c1 = Conv2D(self.gf, kernel_size=9, strides=1, padding='same')(img_lr)
        c1 = Activation('relu')(c1)

        # Propagate through residual blocks
        r = residual_block(c1, self.gf)
        for _ in range(self.n_residual_blocks - 1):
            r = residual_block(r, self.gf)

        # Post-residual block (long skip connection back to c1)
        c2 = Conv2D(self.gf, kernel_size=3, strides=1, padding='same')(r)
        c2 = BatchNormalization(momentum=0.8)(c2)
        c2 = Add()([c2, c1])

        # Upsampling: two 2x stages give the 4x HR resolution
        u1 = deconv2d(c2)
        u2 = deconv2d(u1)

        # Generate high resolution output
        gen_hr = Conv2D(self.channels, kernel_size=9, strides=1, padding='same', activation='tanh')(u2)

        return Model(img_lr, gen_hr)

    def build_discriminator(self):
        """Return a PatchGAN-style model mapping an ``hr_shape`` image to a ``disc_patch`` score map."""

        def d_block(layer_input, filters, strides=1, bn=True):
            """Discriminator layer"""
            d = Conv2D(filters, kernel_size=3, strides=strides, padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            if bn:
                d = BatchNormalization(momentum=0.8)(d)
            return d

        # Input img
        d0 = Input(shape=self.hr_shape)

        d1 = d_block(d0, self.df, bn=False)
        d2 = d_block(d1, self.df, strides=2)
        d3 = d_block(d2, self.df*2)
        d4 = d_block(d3, self.df*2, strides=2)
        d5 = d_block(d4, self.df*4)
        d6 = d_block(d5, self.df*4, strides=2)
        d7 = d_block(d6, self.df*8)
        d8 = d_block(d7, self.df*8, strides=2)

        # Dense on a rank-4 tensor acts on the last (channel) axis, so the spatial
        # layout is kept and the output is one score per patch.
        d9 = Dense(self.df*16)(d8)
        d10 = LeakyReLU(alpha=0.2)(d9)
        validity = Dense(1, activation='sigmoid')(d10)

        return Model(d0, validity)

    def train(self, epochs, batch_size=1, sample_interval=50):
        """Alternate one discriminator step and one generator step per iteration.

        Args:
            epochs: number of iterations; each calls data_loader.load_data twice.
            batch_size: images per batch.
            sample_interval: save sample images every this many iterations.
        """
        start_time = datetime.datetime.now()

        for epoch in range(epochs):

            # ----------------------
            #  Train Discriminator
            # ----------------------

            # Sample images and their conditioning counterparts
            imgs_hr, imgs_lr = self.data_loader.load_data(batch_size)

            # From low res. image generate high res. version
            fake_hr = self.generator.predict(imgs_lr)

            valid = np.ones((batch_size,) + self.disc_patch)
            fake = np.zeros((batch_size,) + self.disc_patch)

            # Train the discriminators (original images = real / generated = Fake)
            d_loss_real = self.discriminator.train_on_batch(imgs_hr, valid)
            d_loss_fake = self.discriminator.train_on_batch(fake_hr, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ------------------
            #  Train Generator
            # ------------------

            # Sample images and their conditioning counterparts
            imgs_hr, imgs_lr = self.data_loader.load_data(batch_size)

            # The generators want the discriminators to label the generated images as real
            valid = np.ones((batch_size,) + self.disc_patch)

            # Extract ground truth image features using pre-trained VGG19 model
            image_features = self.vgg.predict(imgs_hr)

            # Train the generators
            g_loss = self.combined.train_on_batch([imgs_lr, imgs_hr], [valid, image_features])

            elapsed_time = datetime.datetime.now() - start_time
            # Plot the progress
            print ("%d time: %s" % (epoch, elapsed_time))

            # If at save interval => save generated image samples
            if epoch % sample_interval == 0:
                self.sample_images(epoch)

    def sample_images(self, epoch):
        """Save generated-vs-original HR pairs and the LR inputs for two test images."""
        os.makedirs('images/%s' % self.dataset_name, exist_ok=True)
        r, c = 2, 2

        imgs_hr, imgs_lr = self.data_loader.load_data(batch_size=2, is_testing=True)
        fake_hr = self.generator.predict(imgs_lr)

        # Rescale images 0 - 1
        imgs_lr = 0.5 * imgs_lr + 0.5
        fake_hr = 0.5 * fake_hr + 0.5
        imgs_hr = 0.5 * imgs_hr + 0.5

        # Save generated images and the high resolution originals
        titles = ['Generated', 'Original']
        fig, axs = plt.subplots(r, c)
        for row in range(r):
            for col, image in enumerate([fake_hr, imgs_hr]):
                axs[row, col].imshow(image[row])
                axs[row, col].set_title(titles[col])
                axs[row, col].axis('off')
        fig.savefig("images/%s/%d.png" % (self.dataset_name, epoch))
        plt.close(fig)

        # Save low resolution images for comparison
        for i in range(r):
            fig = plt.figure()
            plt.imshow(imgs_lr[i])
            fig.savefig('images/%s/%d_lowres%d.png' % (self.dataset_name, epoch, i))
            plt.close(fig)
if __name__ == '__main__':
    # 30000 single-image iterations, saving samples every 50 iterations.
    gan = SRGAN()
    gan.train(sample_interval=50, batch_size=1, epochs=30000)
# srgan++changed
"""
Super-resolution of CelebA using Generative Adversarial Networks.
The dataset can be downloaded from: https://www.dropbox.com/sh/8oqt9vytwxb3s4r/AADIKlz8PR9zr6Y20qbkunrba/Img/img_align_celeba.zip?dl=0
Instructions for running the script:
1. Download the dataset from the provided link.
2. Save the folder 'img_align_celeba' to 'datasets/'.
3. Run the script using the command 'python srgan.py'.
Note: this modified variant configures its DataLoader with the dataset name
'train_img_crop' (see SRGAN.__init__), not 'img_align_celeba'.
"""
from __future__ import print_function, division
import scipy
from keras.datasets import mnist
# from keras_contrib.layers.normalization import InstanceNormalization
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate
from keras.layers import BatchNormalization, Activation, ZeroPadding2D, Add
from keras.layers.advanced_activations import PReLU, LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.applications import VGG19
from keras.models import Sequential, Model
from keras.optimizers import Adam
import datetime
import matplotlib.pyplot as plt
import sys
from data_loader import DataLoader
import numpy as np
import os
from PIL import Image
from keras.layers import Conv2D, Dense, Input, add, Activation, AveragePooling2D, GlobalAveragePooling2D, MaxPooling2D, Flatten, Dropout
from subpixel import *
import keras.backend as K
import os
# Make only GPU 2 visible to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
# Module-level batch size; also baked into the generator's SubpixelConv2D shape.
# (The former module-level `global batch_size` statement was a no-op and is removed.)
batch_size = 20
class SRGAN():
def __init__(self):
# Input shape
self.channels = 3
self.lr_height = 128 # Low resolution height
self.lr_width = 128 # Low resolution width
self.lr_shape = (self.lr_height, self.lr_width, 4)
self.hr_height = self.lr_height*2 # High resolution height
self.hr_width = self.lr_width*2 # High resolution width
self.hr_shape = (self.hr_height, self.hr_width, 3)
# Number of residual blocks in the generator
self.n_residual_blocks = 16
optimizer = Adam(0.0002, 0.5)
# We use a pre-trained VGG19 model to extract image features from the high resolution
# and the generated high resolution images and minimize the mse between them
# Configure data loader
self.dataset_name = 'train_img_crop'
self.data_loader = DataLoader(dataset_name=self.dataset_name,
img_res=(self.hr_height, self.hr_width))
# Calculate output shape of D (PatchGAN)
patch = int(self.hr_height / 2**4)
self.disc_patch = (patch, patch, 1)
# Number of filters in the first layer of G and D
self.gf = 64
self.df = 64
# Build and compile the discriminator
self.discriminator = self.build_discriminator()
self.discriminator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
# Build and compile the generator
self.generator = self.build_generator()
self.generator.compile(loss='binary_crossentropy', optimizer=optimizer)
# High res. and low res. images
img_hr = Input(shape=self.hr_shape)
img_lr = Input(shape=self.lr_shape)
# Generate high res. version from low res.
fake_hr = self.generator(img_lr)
# Extract image features of the generated img
# For the combined model we will only train the generator
self.discriminator.trainable = False
# Discriminator determines validity of generated high res. images
validity = self.discriminator(fake_hr)
self.combined = Model(img_lr, [validity, fake_hr])
self.combined.compile(loss=['binary_crossentropy', 'mse'],
loss_weights=[1e-3, 1],
optimizer=optimizer)
def build_generator(self):
def bn_relu(x):
x = BatchNormalization()(x)
x = Activation('relu')(x)
return x
def slice(x, i):
""" Define a tensor slice function
"""
return x[:, :, :, i]
# Low resolution image input
img_lr = Input(shape=self.lr_shape)
n = Conv2D(filters=64, kernel_size=(3, 3), strides=(1, 1), padding='same',
kernel_initializer="he_normal")(img_lr)
n = Activation('relu')(n)
tmp = n
for i in range(16):
nn = Conv2D(filters=64, kernel_size=(3, 3), strides=(1, 1), padding='same',
kernel_initializer="he_normal")(n)
nn = bn_relu(nn)
nn = Conv2D(filters=64, kernel_size=(3, 3), strides=(1, 1), padding='same',
kernel_initializer="he_normal")(nn)
nn = BatchNormalization()(nn)
n = add([n, nn])
n = Conv2D(filters=64, kernel_size=(3, 3), strides=(1, 1), padding='same',
kernel_initializer="he_normal")(n)
n = BatchNormalization()(n)
n = add([tmp, n])
n = Conv2D(filters=256, kernel_size=(3, 3), strides=(1, 1), padding='same',
kernel_initializer="he_normal")(n)
n = SubpixelConv2D((batch_size, 16, 16, 256), scale=2)(n)
n = Activation('relu')(n)
n = Conv2D(filters=3, kernel_size=(1, 1), strides=(1, 1), padding='same',
kernel_initializer="he_normal", name='outputs')(n)
gen_hr = Activation('tanh')(n)
return Model(img_lr, gen_hr)
def build_discriminator(self):
def d_block(layer_input, filters, strides=1, bn=True):
"""Discriminator layer"""
d = Conv2D(filters, kernel_size=3, strides=strides, padding='same')(layer_input)
d = LeakyReLU(alpha=0.2)(d)
if bn:
d = BatchNormalization(momentum=0.8)(d)
return d
# Input img
d0 = Input(shape=self.hr_shape)
d1 = d_block(d0, self.df, bn=False)
d2 = d_block(d1, self.df, strides=2)
d3 = d_block(d2, self.df*2)
d4 = d_block(d3, self.df*2, strides=2)
d5 = d_block(d4, self.df*4)
d6 = d_block(d5, self.df*4, strides=2)
d7 = d_block(d6, self.df*8)
d8 = d_block(d7, self.df*8, strides=2)
d9 = Dense(self.df*16)(d8)
d10 = LeakyReLU(alpha=0.2)(d9)
validity = Dense(1, activation='sigmoid')(d10)
return Model(d0, validity)
def train(self, epochs, batch_size=1, sample_interval=50):
start_time = datetime.datetime.now()
for epoch in range(epochs):
# ----------------------
# Train Discriminator
# ----------------------
for idx in range(self.data_loader.count_batch_num()//batch_size):
# Sample images and their conditioning counterparts
imgs_hr, imgs_lr = self.data_loader.load_data(batch_size, idx=idx)
# From low res. image generate high res. version
fake_hr = self.generator.predict(imgs_lr)
valid = np.ones((batch_size,) + self.disc_patch)
fake = np.zeros((batch_size,) + self.disc_patch)
# Train the discriminators (original images = real / generated = Fake)
d_loss_real = self.discriminator.train_on_batch(imgs_hr, valid)
d_loss_fake = self.discriminator.train_on_batch(fake_hr, fake)
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
print('epoch %d:Discriminator batch %d loss: %s' % (epoch, idx, d_loss))
# ------------------
# Train Generator
# ------------------
for idx in range(self.data_loader.count_batch_num()//batch_size):
# Sample images and their conditioning counterparts
imgs_hr, imgs_lr = self.data_loader.load_data(batch_size, idx=idx)
# The generators want the discriminators to label the generated images as real
valid = np.ones((batch_size,) + self.disc_patch)
# Extract ground truth image features using pre-trained VGG19 model
# Train the generators
g_loss = self.combined.train_on_batch(imgs_lr, [valid, imgs_hr])
print('epoch %d:Generator batch %d loss: %s' % (epoch, idx, g_loss))
elapsed_time = datetime.datetime.now() - start_time
# Plot the progress
print("%d time: %s" % (epoch, elapsed_time))
# If at save interval => save generated image samples
if epoch % 5 == 0:
self.sample_images(epoch)
if epoch % 10 == 0:
self.test_images(epoch)
def sample_images(self, epoch):
if os.path.exists('images/%s' % self.dataset_name) == False:
os.makedirs('images/%s' % self.dataset_name)
r, c = 2, 2
imgs_hr, imgs_lr = self.data_loader.load_data(batch_size=2, idx=0, is_testing=True)
fake_hr = self.generator.predict(imgs_lr)
# Rescale images 0 - 1
imgs_lr = 0.5 * imgs_lr + 0.5
fake_hr = 0.5 * fake_hr + 0.5
imgs_hr = 0.5 * imgs_hr + 0.5
# Save generated images and the high resolution originals
titles = ['Generated', 'Original']
fig, axs = plt.subplots(r, c)
cnt = 0
for row in range(r):
for col, image in enumerate([fake_hr, imgs_hr]):
axs[row, col].imshow(image[row])
axs[row, col].set_title(titles[col])
axs[row, col].axis('off')
cnt += 1
fig.savefig("images/changed_plot3/%d.png" % (epoch))
plt.close()
# Save low resolution images for comparison
for i in range(r):
fig = plt.figure()
plt.imshow(imgs_lr[i])
fig.savefig('images/changed_plot3/%d_lowres%d.png' % ( epoch, i))
plt.close()
def psnr(self, im1, im2, max_val=255.0):
    """Return the peak signal-to-noise ratio between two images, in dB.

    Args:
        im1, im2: Arrays of the same shape (any rank, e.g. H x W x C).
        max_val: Peak pixel value; the default 255 matches 8-bit images.

    Returns:
        ``10 * log10(max_val**2 / MSE)``, or ``inf`` for identical inputs.
    """
    # Promote to float before subtracting: unsigned ints (e.g. uint8) would
    # wrap around (0 - 1 -> 255) and grossly inflate the error.
    diff = np.asarray(im1, dtype=np.float64) - np.asarray(im2, dtype=np.float64)
    mse = np.mean(np.square(diff))
    if mse == 0:
        # Identical images: PSNR is unbounded; avoid a divide-by-zero warning.
        return float('inf')
    return 10 * np.log10(max_val * max_val / mse)
def test_images(self, epoch):
    """Run the generator on all test crops, save its outputs and print mean PSNR.

    Outputs are written to ``./images/test_result3/``; PSNR is computed with
    ``self.psnr`` on the [0, 255] scale against the ground-truth crops.

    Args:
        epoch: Epoch number used in the log line and output file names.
    """
    pics_test_real, imgs_lr = self.data_loader.load_test_data_crop()
    if len(pics_test_real) == 0:
        # Nothing to evaluate: avoid predicting on an empty batch and the
        # division by zero in the average below.
        print(' ---epoch ' + str(epoch) + ' is testing: no test images found')
        return
    out_dir = './images/test_result3/'
    # Image.save does not create missing folders.
    os.makedirs(out_dir, exist_ok=True)
    pics_test_predict = self.generator.predict(imgs_lr)
    # Map [-1, 1] back to [0, 255].
    pics_test_real = (pics_test_real + 1) * 127.5
    # Clip and round once, so the saved PNG and the PSNR use the same 8-bit
    # image. The original cast unclipped, unrounded floats to uint8, which
    # truncates and does not clamp out-of-range pixels (they typically wrap,
    # e.g. 256 -> 0).
    pics_test_predict = np.clip((pics_test_predict + 1) * 127.5, 0, 255).round()
    res_psnr = 0.0
    for i, (gt, out) in enumerate(zip(pics_test_real, pics_test_predict)):
        res_psnr += self.psnr(gt, out)
        im = Image.fromarray(out.astype('uint8'))
        im.save(out_dir + 'changed_result_epoch_' + str(epoch) + '_di' + str(i) + '=' + '.png')
    res_psnr = res_psnr / len(pics_test_real)
    print(' ---epoch ' + str(epoch) + ' is testing:' + ' average psnr = ' + str(res_psnr))
if __name__ == '__main__':
    # Script entry point: build the SRGAN and train it for 300 epochs.
    # `batch_size` is presumably a module-level setting defined earlier in
    # this script (not visible here) -- confirm.
    # NOTE(review): the visible tail of train() samples on hard-coded
    # `epoch % 5` / `epoch % 10`, so sample_interval may be ignored -- confirm.
    gan = SRGAN()
    gan.train(epochs=300, sample_interval=10, batch_size=batch_size)
import scipy
from glob import glob
import numpy as np
import matplotlib.pyplot as plt
class DataLoader():
    """Load (target RGB, network input) image pairs scaled to [-1, 1].

    Two pipelines are provided:

    * Bayer pipeline (``load_data``, ``load_data_demo``, ``load_test_data``,
      ``load_test_data_crop``): each image is resized to ``img_res``. The
      input is its RGGB Bayer mosaic packed into four half-resolution channels.
    * 2x super-resolution pipeline (``load_data_sr``, ``load_data_sr_test``):
      the input is the same image resized to half of ``img_res``.

    Training images are read from ``./datasets/<dataset_name>/*`` and test
    images from ``./crop_datasets/test_image/*``.

    NOTE(review): ``scipy.misc.imread``/``imresize`` were deprecated in
    SciPy 1.0 and later removed. This class needs an old SciPy with Pillow
    installed; consider porting.
    """

    def __init__(self, dataset_name, img_res=(256, 256)):
        self.dataset_name = dataset_name
        self.img_res = img_res
        # Training file order for the current epoch; maintained by _epoch_paths.
        self._epoch_order = None

    def Bayer4channel(self, bayer_x):
        """Pack a single-plane Bayer mosaic into four half-resolution channels.

        Args:
            bayer_x: ``(w, h, C)`` array; only channel 0 is read.

        Returns:
            float64 array of shape ``(w // 2, h // 2, 4)``. For each 2x2 tile,
            the channels hold the (even, even), (odd, even), (even, odd) and
            (odd, odd) sites. A trailing odd row/column is dropped.
        """
        w, h, _ = bayer_x.shape
        rows, cols = 2 * (w // 2), 2 * (h // 2)
        plane = bayer_x[:rows, :cols, 0]
        bayer_4x = np.zeros([w // 2, h // 2, 4])
        # Strided slices replace the original per-pixel Python double loop.
        bayer_4x[:, :, 0] = plane[0::2, 0::2]
        bayer_4x[:, :, 1] = plane[1::2, 0::2]
        bayer_4x[:, :, 2] = plane[0::2, 1::2]
        bayer_4x[:, :, 3] = plane[1::2, 1::2]
        return bayer_4x

    def RGB2Bayer(self, RGB):
        """Sample an RGB image with an RGGB colour-filter pattern.

        Each 2x2 tile keeps R at (even, even), G at (odd, even) and
        (even, odd), and B at (odd, odd).

        Args:
            RGB: ``(w, h, C)`` array, C >= 3, channels in R, G, B order.

        Returns:
            float64 array of shape ``(w, h, 1)``. A trailing odd row/column
            is left at zero.
        """
        w, h, _ = RGB.shape
        rows, cols = 2 * (w // 2), 2 * (h // 2)
        bayer = np.zeros([w, h, 1])
        bayer[0:rows:2, 0:cols:2, 0] = RGB[0:rows:2, 0:cols:2, 0]  # R
        bayer[1:rows:2, 0:cols:2, 0] = RGB[1:rows:2, 0:cols:2, 1]  # G
        bayer[0:rows:2, 1:cols:2, 0] = RGB[0:rows:2, 1:cols:2, 1]  # G
        bayer[1:rows:2, 1:cols:2, 0] = RGB[1:rows:2, 1:cols:2, 2]  # B
        return bayer

    def _epoch_paths(self, idx):
        """Return the training file list used for batch ``idx`` of this epoch.

        ``idx == 0`` re-globs and reshuffles. Later indices reuse that order,
        so all batches of an epoch slice a single permutation. (The original
        re-globbed on every call, so only batch 0 saw the shuffled order;
        batches 1, 2, ... came from the unshuffled glob order and could
        overlap batch 0.)
        """
        if idx == 0 or getattr(self, '_epoch_order', None) is None:
            paths = glob('./datasets/%s/*' % (self.dataset_name))
            if idx == 0:
                np.random.shuffle(paths)
            self._epoch_order = paths
        return self._epoch_order

    @staticmethod
    def _to_unit_range(imgs):
        """Stack a list of images and map [0, 255] to [-1, 1]."""
        return np.array(imgs) / 127.5 - 1.

    def _load_bayer_pair(self, img_path, flip=False):
        """Return ``(img_hr, img_lr)`` for one file, before normalization."""
        img_hr = scipy.misc.imresize(self.imread(img_path), self.img_res)
        if flip:
            # Flip the RGB image *before* mosaicking. The original flipped the
            # packed 4-channel mosaic instead. That shifts the input by one
            # pixel and mirrors the 2x2 CFA layout relative to the flipped target.
            img_hr = np.fliplr(img_hr)
        img_lr = self.Bayer4channel(self.RGB2Bayer(img_hr))
        return img_hr, img_lr

    def _load_bayer_batch(self, paths, is_testing):
        """Load Bayer pairs for ``paths``.

        When training, each pair is flipped with probability 0.5.
        """
        imgs_hr, imgs_lr = [], []
        for img_path in paths:
            # Short-circuit leaves the RNG untouched in test mode.
            flip = not is_testing and np.random.random() < 0.5
            img_hr, img_lr = self._load_bayer_pair(img_path, flip)
            imgs_hr.append(img_hr)
            imgs_lr.append(img_lr)
        return self._to_unit_range(imgs_hr), self._to_unit_range(imgs_lr)

    def load_data(self, batch_size=1, idx=0, is_testing=False):
        """Return one batch of (HR RGB, packed Bayer) pairs scaled to [-1, 1].

        Args:
            batch_size: Number of images in the batch.
            idx: Batch index within the epoch. ``idx == 0`` reshuffles the
                training files. Ignored when testing.
            is_testing: If true, draw ``batch_size`` test crops at random
                (with replacement) without augmentation. Otherwise slice
                batch ``idx`` of the shuffled training set, with random flips.

        Returns:
            ``(imgs_hr, imgs_lr)`` numpy arrays.
        """
        if is_testing:
            batch_images = np.random.choice(glob('./crop_datasets/test_image/*'), size=batch_size)
        else:
            paths = self._epoch_paths(idx)
            batch_images = paths[idx * batch_size:(idx + 1) * batch_size]
        return self._load_bayer_batch(batch_images, is_testing)

    def load_data_demo(self, batch_size=1, idx=0, is_testing=False, issampled=False):
        """Like ``load_data``, but load the whole image set by default.

        When testing, every test crop is returned. When training, the whole
        (shuffled on ``idx == 0``) training set is returned, or only batch
        ``idx`` if ``issampled`` is true.
        """
        if is_testing:
            batch_images = glob('./crop_datasets/test_image/*')
        else:
            batch_images = self._epoch_paths(idx)
            if issampled:
                batch_images = batch_images[idx * batch_size:(idx + 1) * batch_size]
        return self._load_bayer_batch(batch_images, is_testing)

    def imread(self, path):
        """Read ``path`` as an RGB float64 array."""
        # np.float was an alias of float64 and was removed in NumPy 1.24.
        return scipy.misc.imread(path, mode='RGB').astype(np.float64)

    def count_batch_num(self):
        """Return the number of training images (not batches) in the dataset."""
        return len(glob('./datasets/%s/*' % (self.dataset_name)))

    def load_test_data(self):
        """Return every test crop as (HR RGB, packed Bayer) pairs, unaugmented."""
        return self._load_bayer_batch(glob('./crop_datasets/test_image/*'), is_testing=True)

    def _load_sr_batch(self, paths, is_testing):
        """Load (HR, half-resolution) RGB pairs; randomly flip both when training."""
        h, w = self.img_res
        low_res = (int(h / 2), int(w / 2))
        imgs_hr, imgs_lr = [], []
        for img_path in paths:
            img = self.imread(img_path)
            img_hr = scipy.misc.imresize(img, self.img_res)
            img_lr = scipy.misc.imresize(img, low_res)
            # Both images are plain RGB views of the same scene, so flipping
            # them together keeps the pair consistent.
            if not is_testing and np.random.random() < 0.5:
                img_hr = np.fliplr(img_hr)
                img_lr = np.fliplr(img_lr)
            imgs_hr.append(img_hr)
            imgs_lr.append(img_lr)
        return self._to_unit_range(imgs_hr), self._to_unit_range(imgs_lr)

    def load_data_sr(self, batch_size=1, is_testing=False):
        """Return all training images as (HR, half-resolution RGB) pairs.

        ``batch_size`` is accepted for interface compatibility but unused:
        every file under ``./datasets/<dataset_name>/`` is loaded.
        """
        return self._load_sr_batch(glob('./datasets/%s/*' % (self.dataset_name)), is_testing)

    def load_data_sr_test(self):
        """Return every test crop as (HR, half-resolution RGB) pairs, unaugmented."""
        return self._load_sr_batch(glob('./crop_datasets/test_image/*'), is_testing=True)

    def load_test_data_crop(self):
        """Return every test crop as (HR RGB, packed Bayer) pairs, unaugmented."""
        return self._load_bayer_batch(glob('./crop_datasets/test_image/*'), is_testing=True)