TensorFlow / Keras DeblurGAN reproduction

Code:

import numpy as np
import matplotlib
from matplotlib import pyplot as plt

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers,optimizers,losses
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.python.keras import backend as K
from tensorflow.keras.utils import plot_model
from IPython.display import Image
from tensorflow.keras.applications.vgg16 import VGG16

import cv2
import PIL
import json, os
import sys

import labelme
import labelme.utils as utils
import glob
import itertools

class DebulgGan():
    """DeblurGAN reproduction (TensorFlow / Keras).

    Generator: ResNet-style image-to-image network with a global input->output
    skip connection. Discriminator: strided CNN with a dense head. Training
    combines a VGG16 perceptual loss (weight 100) with a Wasserstein
    adversarial loss (weight 1), following the DeblurGAN paper.
    """

    def __init__(self):
        # --- architecture hyper-parameters -------------------------------
        self.image_shape = (256, 256, 3)
        self.ngf = 64                     # base filter count, generator
        self.ndf = 64                     # base filter count, discriminator
        self.input_nc = 3
        self.output_nc = 3                # RGB output channels
        self.input_shape_generator = (256, 256, 3)
        self.n_blocks_gen = 9             # number of generator ResNet blocks
        # --- training hyper-parameters -----------------------------------
        self.epochs = 100
        self.batch_size = 5
        self.train_number = 20000         # number of (blur, sharp) pairs used
        # --- data / output locations (adjust to your environment) --------
        self.blur_path = r'F:\BaiduNetdiskDownload\deblugData\train\x'
        self.sharp_path = r'F:\BaiduNetdiskDownload\deblugData\train\y'
        self.img_savepath = r'C:\Users\Administrator\Desktop\photo'
        self.model_path = r'C:\Users\Administrator\Desktop\photo\deblurGAN.h5'
        # --- build the networks ------------------------------------------
        self.generator = self.generator_model()
        self.discriminator = self.discriminator_model()
        self.model = self.generator_containing_discriminator_multiple_outputs()
        self.loss_model = self.bulid_loss_model()

    def res_block(self, input, filters, kernel_size=(3, 3), strides=(1, 1), use_dropout=False):
        """Residual block: two conv+BN layers bridged by an identity shortcut.

        `strides` must keep the spatial size unchanged (default (1, 1)),
        otherwise the final Add() would fail on mismatched shapes.
        """
        x = layers.Conv2D(filters=filters, kernel_size=kernel_size, strides=strides, padding='same')(input)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)

        if use_dropout:
            x = layers.Dropout(0.5)(x)

        x = layers.Conv2D(filters=filters, kernel_size=kernel_size, strides=strides, padding='same')(x)
        x = layers.BatchNormalization()(x)

        # Identity shortcut around the two convolutions.
        merged = layers.Add()([input, x])
        return merged

    def generator_model(self):
        """Build the generator: downsample, 9 ResNet blocks, upsample.

        Returns a Model mapping a [-1, 1] blurred image to a [-1, 1]
        deblurred image (tanh output averaged with the input).
        """
        inputs = keras.Input(shape=self.image_shape)

        x = layers.Conv2D(filters=self.ngf, kernel_size=(7, 7), padding='same')(inputs)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)

        # Downsample twice, doubling the filter count each time.
        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2 ** i
            x = layers.Conv2D(filters=self.ngf * mult * 2, kernel_size=(3, 3), strides=2, padding='same')(x)
            x = layers.BatchNormalization()(x)
            x = layers.Activation('relu')(x)

        # 9 ResNet blocks at the bottleneck resolution.
        mult = 2 ** n_downsampling
        for i in range(self.n_blocks_gen):
            x = self.res_block(x, self.ngf * mult, use_dropout=True)

        # Upsample back to the input resolution, halving filters each step.
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            x = layers.Conv2DTranspose(filters=int(self.ngf * mult / 2), kernel_size=(3, 3), strides=2, padding='same')(x)
            x = layers.BatchNormalization()(x)
            x = layers.Activation('relu')(x)

        # Project to 3 channels (RGB) in [-1, 1].
        x = layers.Conv2D(filters=self.output_nc, kernel_size=(7, 7), padding='same')(x)
        x = layers.Activation('tanh')(x)

        # Global skip connection; halving keeps the sum inside [-1, 1].
        outputs = layers.Add()([x, inputs])
        outputs = layers.Lambda(lambda z: z / 2)(outputs)

        model = keras.Model(inputs=inputs, outputs=outputs, name='Generator')
        return model

    def discriminator_model(self):
        """Build the discriminator: strided convs plus a dense sigmoid head."""
        n_layers, use_sigmoid = 3, False
        inputs = keras.Input(shape=self.image_shape)

        x = layers.Conv2D(filters=self.ndf, kernel_size=(4, 4), strides=2, padding='same')(inputs)
        x = layers.LeakyReLU(0.2)(x)

        # NOTE(review): min(2 ** n, 8) starts at n=0, so the first loop
        # iteration does not grow the filter count; reference PatchGAN
        # implementations use 2 ** (n + 1) here — TODO confirm intent.
        nf_mult, nf_mult_prev = 1, 1
        for n in range(n_layers):
            nf_mult_prev, nf_mult = nf_mult, min(2 ** n, 8)
            x = layers.Conv2D(filters=self.ndf * nf_mult, kernel_size=(4, 4), strides=2, padding='same')(x)
            x = layers.BatchNormalization()(x)
            x = layers.LeakyReLU(0.2)(x)

        nf_mult_prev, nf_mult = nf_mult, min(2 ** n_layers, 8)
        x = layers.Conv2D(filters=self.ndf * nf_mult, kernel_size=(4, 4), strides=1, padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.LeakyReLU(0.2)(x)

        x = layers.Conv2D(filters=1, kernel_size=(4, 4), strides=1, padding='same')(x)
        if use_sigmoid:
            x = layers.Activation('sigmoid')(x)

        x = layers.Flatten()(x)
        x = layers.Dense(1024, activation='tanh')(x)
        x = layers.Dense(1, activation='sigmoid')(x)

        model = keras.Model(inputs=inputs, outputs=x, name='Discriminator')
        return model

    def generator_containing_discriminator_multiple_outputs(self):
        """Stack generator and discriminator into the combined GAN model.

        Outputs both the generated image (for the perceptual loss) and the
        discriminator score (for the adversarial loss).
        """
        inputs = keras.Input(shape=self.image_shape)
        generated_images = self.generator(inputs)
        outputs = self.discriminator(generated_images)
        model = keras.Model(inputs=inputs, outputs=[generated_images, outputs])
        return model

    def bulid_loss_model(self):
        """Build the frozen VGG16 feature extractor used by the perceptual loss."""
        vgg = VGG16(include_top=False, weights='imagenet', input_shape=self.image_shape)
        loss_model = keras.Model(inputs=vgg.input, outputs=vgg.get_layer('block3_conv3').output)
        loss_model.trainable = False
        return loss_model

    def perceptual_loss(self, y_true, y_pred):
        """Mean squared error between VGG16 block3_conv3 feature maps."""
        return tf.reduce_mean(K.square(self.loss_model(y_true) - self.loss_model(y_pred)))

    def wasserstein_loss(self, y_true, y_pred):
        """Wasserstein critic loss: mean of label (+1/-1) times score."""
        return tf.reduce_mean(y_true * y_pred)

    def compile(self):
        """Compile the discriminator, then the combined model with the
        discriminator frozen so `model.train_on_batch` only updates the
        generator (trainability is captured at compile time)."""
        self.discriminator.trainable = True
        # BUGFIX: `lr=` is deprecated in TF2 optimizers; use `learning_rate=`.
        self.discriminator.compile(
            optimizer=keras.optimizers.Adam(learning_rate=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08),
            loss=self.wasserstein_loss)
        self.discriminator.trainable = False
        self.model.compile(
            optimizer=keras.optimizers.Adam(learning_rate=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08),
            loss=[self.perceptual_loss, self.wasserstein_loss],
            loss_weights=[100, 1])

    def load_data(self, blur_imgs, sharp_imgs, train_idx, step):
        """Read one mini-batch of (blurred, sharp) image pairs.

        Pixels are normalised to [-1, 1] so the data matches the generator's
        tanh-centred output and the `(x + 1) * 127.5` denormalisation used
        when saving samples.
        """
        blur_batch = []
        sharp_batch = []
        for j in range(self.batch_size):
            idx = train_idx[step * self.batch_size + j]
            # BUGFIX: the original divided by 255 after subtracting 127.5,
            # mapping pixels to [-0.5, 0.5] — inconsistent with the [-1, 1]
            # convention used everywhere else. Divide by 127.5 instead.
            blur_batch.append((cv2.imread(blur_imgs[idx], 1) - 127.5) / 127.5)
            sharp_batch.append((cv2.imread(sharp_imgs[idx], 1) - 127.5) / 127.5)
        return np.array(blur_batch), np.array(sharp_batch)

    def train(self):
        """Alternating adversarial training loop.

        Per step: train the discriminator on one real and one generated
        batch (WGAN labels +1/-1), then train the generator through the
        frozen-discriminator combined model. Saves sample images every
        500 steps and checkpoints the model once per epoch.
        """
        self.compile()
        self.model.summary()
        # Collect and align the blurred / sharp file lists (sorted so the
        # i-th blurred image corresponds to the i-th sharp image).
        blur_location = sorted(glob.glob(self.blur_path + '/*.png'))
        sharp_location = sorted(glob.glob(self.sharp_path + '/*.png'))
        train_idx = np.arange(self.train_number)
        steps = self.train_number // self.batch_size
        # WGAN labels: +1 for real (sharp) images, -1 for generated ones.
        output_true_batch = np.ones((self.batch_size, 1))
        output_false_batch = -np.ones((self.batch_size, 1))

        for epoch in range(self.epochs):
            # Reshuffle the sample order each epoch (numpy shuffle replaces
            # the original tf.random.shuffle(...).numpy() round-trip).
            train_idx = np.random.permutation(train_idx)

            for step in range(steps):
                blur_imgs, sharp_imgs = self.load_data(blur_location, sharp_location, train_idx, step)
                gan_imgs = self.generator.predict(blur_imgs)

                # --- discriminator update --------------------------------
                # BUGFIX: re-enable training here; the original set
                # trainable = False inside the loop and never restored it.
                self.discriminator.trainable = True
                d_loss_real = self.discriminator.train_on_batch(sharp_imgs, output_true_batch)
                d_loss_fake = self.discriminator.train_on_batch(gan_imgs, output_false_batch)
                discriminator_loss = 0.5 * np.add(d_loss_fake, d_loss_real)

                # --- generator update (discriminator frozen) -------------
                self.discriminator.trainable = False
                generator_loss = self.model.train_on_batch(blur_imgs, [sharp_imgs, output_true_batch])

                # generator_loss[0] is the weighted total loss.
                print("epoch:%d step:%d [discriminator_loss: %f] [generator_loss: %f]" % (
                    epoch, step, discriminator_loss, generator_loss[0]))
                if step % 500 == 0:
                    self.generate_sample_images(gan_imgs, sharp_imgs, epoch, step)
            self.model.save(self.model_path)  # checkpoint once per epoch
            print('save model')

    def generate_sample_images(self, gan_imgs, sharp_imgs, epoch, step):
        """Save the first generated/sharp pair of the batch for monitoring."""
        idx = 0
        # Map [-1, 1] back to [0, 255]; clip before the uint8 cast so tanh
        # overshoot cannot wrap around (BUGFIX: the original subtracted a
        # 0.0001 fudge factor and could still overflow).
        blur = np.clip((gan_imgs[idx] + 1) * 127.5, 0, 255).astype(np.uint8)
        sharp = np.clip((sharp_imgs[idx] + 1) * 127.5, 0, 255).astype(np.uint8)
        print((self.img_savepath + "/%d.%d_blur.png" % (epoch, step)))
        cv2.imwrite((self.img_savepath + "/%d.%d_blur.png" % (epoch, step)), blur)
        cv2.imwrite((self.img_savepath + "/%d.%d_sharp.png" % (epoch, step)), sharp)
        print('save plot')
# Guard the entry point so importing this module does not start training.
if __name__ == "__main__":
    deblurGAN = DebulgGan()
    deblurGAN.train()
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值