# Code (stray paste label "代码" kept as a comment; as a bare statement it raised NameError)
import numpy as np
import matplotlib
from matplotlib import pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers,optimizers,losses
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.python.keras import backend as K
from tensorflow.keras.utils import plot_model
from IPython.display import Image
from tensorflow.keras.applications.vgg16 import VGG16
import cv2
import PIL
import json, os
import sys
import labelme
import labelme.utils as utils
import glob
import itertools
class DebulgGan():
    """DeblurGAN-style image deblurring GAN built with tf.keras.

    Holds a ResNet-style generator, a convolutional discriminator, the
    combined generator->discriminator model used to train the generator,
    and a frozen VGG16 feature extractor used by the perceptual loss.
    Images are fed to all networks scaled to [-1, 1].

    NOTE: the class name keeps its original spelling because external code
    instantiates it by that name.
    """

    def __init__(self):
        # Image / network geometry.
        self.image_shape = (256, 256, 3)
        self.ngf = 64  # base filter count of the generator
        self.ndf = 64  # base filter count of the discriminator
        self.input_nc = 3
        self.output_nc = 3  # channels produced by the generator's last conv
        self.input_shape_generator = (256, 256, 3)
        self.n_blocks_gen = 9  # residual blocks in the generator bottleneck
        # Training schedule.
        self.epochs = 100
        self.batch_size = 5
        self.train_number = 20000  # upper bound on image pairs used per epoch
        # Data / output locations.
        self.blur_path = r'F:\BaiduNetdiskDownload\deblugData\train\x'
        self.sharp_path = r'F:\BaiduNetdiskDownload\deblugData\train\y'
        self.img_savepath = r'C:\Users\Administrator\Desktop\photo'
        self.model_path = r'C:\Users\Administrator\Desktop\photo\deblurGAN.h5'
        # Networks.
        self.generator = self.generator_model()
        self.discriminator = self.discriminator_model()
        self.model = self.generator_containing_discriminator_multiple_outputs()
        self.loss_model = self.bulid_loss_model()

    def res_block(self, input, filters, kernel_size=(3, 3), strides=(1, 1), use_dropout=False):
        """Return ``input + F(input)``, F = conv-BN-ReLU-[dropout]-conv-BN.

        ``filters`` must equal the channel count of ``input`` (and ``strides``
        must keep the spatial size) for the residual addition to be valid.
        """
        x = layers.Conv2D(filters=filters, kernel_size=kernel_size, strides=strides, padding='same')(input)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)
        if use_dropout:
            x = layers.Dropout(0.5)(x)
        x = layers.Conv2D(filters=filters, kernel_size=kernel_size, strides=strides, padding='same')(x)
        x = layers.BatchNormalization()(x)
        return layers.Add()([input, x])

    def generator_model(self):
        """Build the generator: ResNet encoder/decoder with a global skip.

        7x7 conv -> 2 stride-2 downsampling convs -> ``n_blocks_gen`` residual
        blocks -> 2 stride-2 transposed convs -> 7x7 conv + tanh. The tanh
        output is added to the input image and halved, so the result stays in
        (-1, 1) whenever the input is in [-1, 1].
        """
        inputs = keras.Input(shape=self.image_shape)
        x = layers.Conv2D(filters=self.ngf, kernel_size=(7, 7), padding='same')(inputs)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)

        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2 ** i
            x = layers.Conv2D(filters=self.ngf * mult * 2, kernel_size=(3, 3), strides=2, padding='same')(x)
            x = layers.BatchNormalization()(x)
            x = layers.Activation('relu')(x)

        mult = 2 ** n_downsampling
        for _ in range(self.n_blocks_gen):
            x = self.res_block(x, self.ngf * mult, use_dropout=True)

        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            x = layers.Conv2DTranspose(filters=self.ngf * mult // 2, kernel_size=(3, 3), strides=2,
                                       padding='same')(x)
            x = layers.BatchNormalization()(x)
            x = layers.Activation('relu')(x)

        x = layers.Conv2D(filters=self.output_nc, kernel_size=(7, 7), padding='same')(x)
        x = layers.Activation('tanh')(x)
        # Global skip: the network predicts a residual on top of the blurred input.
        outputs = layers.Add()([x, inputs])
        outputs = layers.Lambda(lambda z: z / 2)(outputs)
        return keras.Model(inputs=inputs, outputs=outputs, name='Generator')

    def discriminator_model(self):
        """Build the discriminator (critic).

        Four stride-2 4x4 convs, two stride-1 4x4 convs, then
        Flatten -> Dense(1024, tanh) -> Dense(1, sigmoid): one score per image.
        """
        n_layers = 3
        inputs = keras.Input(shape=self.image_shape)
        x = layers.Conv2D(filters=self.ndf, kernel_size=(4, 4), strides=2, padding='same')(inputs)
        x = layers.LeakyReLU(0.2)(x)
        for n in range(n_layers):
            nf_mult = min(2 ** n, 8)
            x = layers.Conv2D(filters=self.ndf * nf_mult, kernel_size=(4, 4), strides=2, padding='same')(x)
            x = layers.BatchNormalization()(x)
            x = layers.LeakyReLU(0.2)(x)
        nf_mult = min(2 ** n_layers, 8)
        x = layers.Conv2D(filters=self.ndf * nf_mult, kernel_size=(4, 4), strides=1, padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.LeakyReLU(0.2)(x)
        x = layers.Conv2D(filters=1, kernel_size=(4, 4), strides=1, padding='same')(x)
        x = layers.Flatten()(x)
        x = layers.Dense(1024, activation='tanh')(x)
        # NOTE(review): the sigmoid bounds the critic output to (0, 1), unlike a
        # standard WGAN critic; kept as in the original design.
        x = layers.Dense(1, activation='sigmoid')(x)
        return keras.Model(inputs=inputs, outputs=x, name='Discriminator')

    def generator_containing_discriminator_multiple_outputs(self):
        """Chain generator -> discriminator; outputs [generated image, score]."""
        inputs = keras.Input(shape=self.image_shape)
        generated_images = self.generator(inputs)
        outputs = self.discriminator(generated_images)
        return keras.Model(inputs=inputs, outputs=[generated_images, outputs])

    def bulid_loss_model(self):
        """Return a frozen VGG16 (ImageNet weights) truncated at block3_conv3."""
        vgg = VGG16(include_top=False, weights='imagenet', input_shape=self.image_shape)
        loss_model = keras.Model(inputs=vgg.input, outputs=vgg.get_layer('block3_conv3').output)
        loss_model.trainable = False
        return loss_model

    def perceptual_loss(self, y_true, y_pred):
        """Mean squared difference of VGG16 block3_conv3 features.

        NOTE(review): images reach VGG16 in [-1, 1] rather than the
        ``vgg16.preprocess_input`` range; confirm this is intended.
        """
        # tf.square rather than the private tensorflow.python.keras backend,
        # which should not be mixed with tf.keras tensors.
        return tf.reduce_mean(tf.square(self.loss_model(y_true) - self.loss_model(y_pred)))

    def wasserstein_loss(self, y_true, y_pred):
        """Mean of ``y_true * y_pred`` (labels are +1 / -1)."""
        return tf.reduce_mean(y_true * y_pred)

    def compile(self):
        """Compile the discriminator, then freeze it and compile the combined model.

        tf.keras captures each layer's ``trainable`` flag at compile time, so the
        discriminator still learns in its own ``train_on_batch`` calls while it
        stays frozen inside ``self.model``.
        """
        self.discriminator.compile(
            optimizer=keras.optimizers.Adam(learning_rate=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08),
            loss=self.wasserstein_loss)
        self.discriminator.trainable = False
        self.model.compile(
            optimizer=keras.optimizers.Adam(learning_rate=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08),
            loss=[self.perceptual_loss, self.wasserstein_loss],
            loss_weights=[100, 1])

    @staticmethod
    def _normalize_image(img):
        """Map 8-bit pixel values [0, 255] to float32 [-1, 1]."""
        return (np.asarray(img, dtype=np.float32) - 127.5) / 127.5

    @staticmethod
    def _denormalize_image(img):
        """Map values in [-1, 1] back to uint8 [0, 255], clipping anything outside."""
        return np.clip(np.rint((np.asarray(img) + 1.0) * 127.5), 0, 255).astype(np.uint8)

    def _read_image(self, path):
        """Read ``path`` as a 3-channel image and check it matches ``image_shape``.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file.
            ValueError: if the image size differs from ``self.image_shape``.
        """
        img = cv2.imread(path, cv2.IMREAD_COLOR)
        if img is None:  # cv2.imread signals failure by returning None
            raise FileNotFoundError('could not read image: %s' % path)
        if img.shape != self.image_shape:
            raise ValueError('image %s has shape %s, expected %s' % (path, img.shape, self.image_shape))
        return img

    def load_data(self, blur_imgs, sharp_imgs, trian_idx, step):
        """Load batch number ``step`` of (blurred, sharp) image pairs.

        Args:
            blur_imgs: blurred-image paths, aligned by index with ``sharp_imgs``.
            sharp_imgs: sharp-image paths.
            trian_idx: this epoch's (shuffled) order of pair indices.
            step: batch number; uses trian_idx[step*batch_size:(step+1)*batch_size].

        Returns:
            Two float32 arrays of shape (batch_size,) + image_shape in [-1, 1].
        """
        start = step * self.batch_size
        batch_ids = [trian_idx[k] for k in range(start, start + self.batch_size)]
        blur_batch = np.stack([self._normalize_image(self._read_image(blur_imgs[i])) for i in batch_ids])
        sharp_batch = np.stack([self._normalize_image(self._read_image(sharp_imgs[i])) for i in batch_ids])
        return blur_batch, sharp_batch

    def train(self):
        """Run the adversarial training loop.

        Each step trains the discriminator on sharp images (label 1) and
        generated images (label -1). It then trains the generator through the
        combined model with 100 * perceptual + 1 * Wasserstein loss. Every 500
        steps a sample pair is written and the combined model is saved.

        Raises:
            FileNotFoundError: if no .png images are found in ``blur_path``.
            ValueError: if the two folders hold different image counts or fewer
                images than one batch.
        """
        self.compile()
        self.model.summary()
        blur_location = sorted(glob.glob(os.path.join(self.blur_path, '*.png')))
        sharp_location = sorted(glob.glob(os.path.join(self.sharp_path, '*.png')))
        if not blur_location:
            raise FileNotFoundError('no .png images found in %s' % self.blur_path)
        if len(blur_location) != len(sharp_location):
            raise ValueError('blurred/sharp image counts differ: %d vs %d'
                             % (len(blur_location), len(sharp_location)))
        # Never index past the image pairs that actually exist.
        n_train = min(self.train_number, len(blur_location))
        steps = n_train // self.batch_size
        if steps == 0:
            raise ValueError('need at least %d image pairs, found %d' % (self.batch_size, n_train))
        output_true_batch = np.ones((self.batch_size, 1))
        output_false_batch = -np.ones((self.batch_size, 1))
        os.makedirs(os.path.dirname(self.model_path) or '.', exist_ok=True)

        for epoch in range(self.epochs):
            train_idx = np.random.permutation(n_train)
            for step in range(steps):
                blur_imgs, sharp_imgs = self.load_data(blur_location, sharp_location, train_idx, step)
                # predict_on_batch avoids predict()'s per-call input-pipeline overhead.
                gan_imgs = self.generator.predict_on_batch(blur_imgs)
                d_loss_real = self.discriminator.train_on_batch(sharp_imgs, output_true_batch)
                d_loss_fake = self.discriminator.train_on_batch(gan_imgs, output_false_batch)
                discriminator_loss = 0.5 * np.add(d_loss_fake, d_loss_real)
                # The discriminator was frozen inside self.model at compile
                # time, so this step only updates the generator.
                generator_loss = self.model.train_on_batch(blur_imgs, [sharp_imgs, output_true_batch])
                print("epoch:%d step:%d [discriminator_loss: %f] [generator_loss: %f]" % (
                    epoch, step, discriminator_loss, generator_loss[0]))
                if step % 500 == 0:
                    self.generate_sample_images(gan_imgs, sharp_imgs, epoch, step)
                    self.model.save(self.model_path)
                    print('save model')

    def generate_sample_images(self, gan_imgs, sharp_imgs, epoch, step):
        """Write the first generated image and its sharp target to ``img_savepath``.

        Files are named ``<epoch>.<step>_blur.png`` (generator output) and
        ``<epoch>.<step>_sharp.png`` (ground truth).
        """
        idx = 0
        deblurred = self._denormalize_image(gan_imgs[idx])
        sharp = self._denormalize_image(sharp_imgs[idx])
        os.makedirs(self.img_savepath, exist_ok=True)
        blur_file = os.path.join(self.img_savepath, "%d.%d_blur.png" % (epoch, step))
        sharp_file = os.path.join(self.img_savepath, "%d.%d_sharp.png" % (epoch, step))
        print(blur_file)
        for path, image in ((blur_file, deblurred), (sharp_file, sharp)):
            # cv2.imwrite reports failure through its return value, not an exception.
            if not cv2.imwrite(path, image):
                print('warning: failed to write %s' % path)
        print('save plot')
# Script entry point: build the networks and start training. Guarded so that
# importing this module does not immediately launch a long training run.
if __name__ == "__main__":
    deblurGAN = DebulgGan()
    deblurGAN.train()