DeblurGAN v1 的 TensorFlow 实现

参考链接:

https://blog.csdn.net/z704630835/article/details/84135277

https://github.com/dongheehand/DeblurGAN-tf

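数据集: 成对的模糊/清晰图像(两个目录中的文件按文件名排序后一一对应),默认目录为 data/train/blur、data/train/sharp、data/test/blur、data/test/sharp。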

代码实现:

data_loader.py

import tensorflow as tf
import os


class dataloader():
    """tf.data input pipeline yielding (blur, sharp) float32 image pairs.

    Files in the blur and sharp directories are paired by position after
    sorting each directory listing, so both directories must contain the same
    number of images in the same (name) order. ``build_loader`` must be called
    to create the datasets; it sets ``self.next_batch`` and ``self.init_op``.
    """

    def __init__(self, args):
        """Copy the loader configuration out of the parsed ``args``."""
        # Every image is decoded with exactly this many channels.
        self.channel = 3

        self.mode = args.mode
        self.patch_size = args.patch_size
        self.batch_size = args.batch_size
        self.train_Sharp_path = args.train_Sharp_path
        self.train_Blur_path = args.train_Blur_path
        self.test_Sharp_path = args.test_Sharp_path
        self.test_Blur_path = args.test_Blur_path
        self.test_with_train = args.test_with_train
        self.test_batch = args.test_batch
        self.load_X = args.load_X
        self.load_Y = args.load_Y
        self.augmentation = args.augmentation

    @staticmethod
    def _paired_paths(blur_dir, sharp_dir):
        """Return ``(blur_paths, sharp_paths)``, each sorted by file name.

        :raises ValueError: if the two directories hold different numbers of
            entries, because pairing is purely positional.
        """
        blur_paths = [os.path.join(blur_dir, ele) for ele in sorted(os.listdir(blur_dir))]
        sharp_paths = [os.path.join(sharp_dir, ele) for ele in sorted(os.listdir(sharp_dir))]
        if len(blur_paths) != len(sharp_paths):
            raise ValueError(
                "blur/sharp directories differ in size: %s has %d entries, %s has %d"
                % (blur_dir, len(blur_paths), sharp_dir, len(sharp_paths)))
        return blur_paths, sharp_paths

    def _build_val_dataset(self, batch_size):
        """Dataset of full-resolution (blur, sharp) test pairs batched by ``batch_size``."""
        valid_list = self._paired_paths(self.test_Blur_path, self.test_Sharp_path)
        val_dataset = tf.data.Dataset.from_tensor_slices(valid_list)
        val_dataset = val_dataset.map(self._parse, num_parallel_calls=4).prefetch(32)
        return val_dataset.batch(batch_size)

    def build_loader(self):
        """Build the dataset(s) for ``self.mode`` and a reinitializable iterator.

        In 'train' mode, ``self.init_op`` gets 'tr_init', plus 'val_init' when
        ``test_with_train`` is set. In 'test' mode it gets only 'val_init', and
        test images are fed one per batch.
        """
        if self.mode == 'train':

            train_list = self._paired_paths(self.train_Blur_path, self.train_Sharp_path)

            self.tr_dataset = tf.data.Dataset.from_tensor_slices(train_list)
            self.tr_dataset = self.tr_dataset.map(self._parse, num_parallel_calls=4).prefetch(32)
            self.tr_dataset = self.tr_dataset.map(self._resize, num_parallel_calls=4).prefetch(32)
            self.tr_dataset = self.tr_dataset.map(self._get_patch, num_parallel_calls=4).prefetch(32)
            if self.augmentation:
                self.tr_dataset = self.tr_dataset.map(self._data_augmentation, num_parallel_calls=4).prefetch(32)
            self.tr_dataset = self.tr_dataset.shuffle(32)
            self.tr_dataset = self.tr_dataset.repeat()
            self.tr_dataset = self.tr_dataset.batch(self.batch_size)

            if self.test_with_train:
                self.val_dataset = self._build_val_dataset(self.test_batch)

            # One iterator shared by the train and validation datasets; switch
            # between them by running the matching initializer.
            iterator = tf.data.Iterator.from_structure(self.tr_dataset.output_types, self.tr_dataset.output_shapes)
            self.next_batch = iterator.get_next()
            self.init_op = {}
            self.init_op['tr_init'] = iterator.make_initializer(self.tr_dataset)

            if self.test_with_train:
                self.init_op['val_init'] = iterator.make_initializer(self.val_dataset)

        elif self.mode == 'test':

            self.val_dataset = self._build_val_dataset(1)

            iterator = tf.data.Iterator.from_structure(self.val_dataset.output_types, self.val_dataset.output_shapes)
            self.next_batch = iterator.get_next()
            self.init_op = {}
            self.init_op['val_init'] = iterator.make_initializer(self.val_dataset)

    def _parse(self, image_blur, image_sharp):
        """Read and decode a (blur, sharp) file pair into float32 HxWxC tensors (values 0..255)."""
        image_blur = tf.read_file(image_blur)
        image_sharp = tf.read_file(image_sharp)

        image_blur = tf.image.decode_image(image_blur, channels=self.channel)
        image_sharp = tf.image.decode_image(image_sharp, channels=self.channel)

        # decode_image leaves the rank unknown (it could be an animated GIF);
        # later ops such as resize_images need a static rank of 3.
        image_blur.set_shape([None, None, self.channel])
        image_sharp.set_shape([None, None, self.channel])

        image_blur = tf.cast(image_blur, tf.float32)
        image_sharp = tf.cast(image_sharp, tf.float32)

        return image_blur, image_sharp

    def _resize(self, image_blur, image_sharp):
        """Bicubically resize both images to (load_Y, load_X)."""
        image_blur = tf.image.resize_images(image_blur, (self.load_Y, self.load_X), tf.image.ResizeMethod.BICUBIC)
        image_sharp = tf.image.resize_images(image_sharp, (self.load_Y, self.load_X), tf.image.ResizeMethod.BICUBIC)

        return image_blur, image_sharp

    def _parse_Blur_only(self, image_blur):
        """Read and decode a single blurred image into a float32 HxWxC tensor."""
        image_blur = tf.read_file(image_blur)
        image_blur = tf.image.decode_image(image_blur, channels=self.channel)
        # Same static-rank fix as in _parse.
        image_blur.set_shape([None, None, self.channel])
        image_blur = tf.cast(image_blur, tf.float32)

        return image_blur

    def _get_patch(self, image_blur, image_sharp):
        """Crop the same random ``patch_size`` square from both images."""
        shape = tf.shape(image_blur)
        ih = shape[0]
        iw = shape[1]

        # maxval is exclusive, so "+ 1" makes every valid offset reachable.
        ix = tf.random_uniform(shape=[1], minval=0, maxval=iw - self.patch_size + 1, dtype=tf.int32)[0]
        iy = tf.random_uniform(shape=[1], minval=0, maxval=ih - self.patch_size + 1, dtype=tf.int32)[0]

        img_sharp_in = image_sharp[iy:iy + self.patch_size, ix:ix + self.patch_size]
        img_blur_in = image_blur[iy:iy + self.patch_size, ix:ix + self.patch_size]

        return img_blur_in, img_sharp_in

    def _data_augmentation(self, image_blur, image_sharp):
        """Apply the same random rotation and flips to both images.

        Rotation is by 0/90/180/270 degrees. The left-right and up-down flips
        are independent, each with probability 1/2.
        """
        # maxval is exclusive. 4 covers all four rotations (the old 3 never
        # rotated by 270 degrees). 2 gives fair coins (the old 3 combined with
        # "mod 2 == 0" flipped with probability 2/3).
        rot = tf.random_uniform(shape=[1], minval=0, maxval=4, dtype=tf.int32)[0]
        flip_rl = tf.random_uniform(shape=[1], minval=0, maxval=2, dtype=tf.int32)[0]
        flip_updown = tf.random_uniform(shape=[1], minval=0, maxval=2, dtype=tf.int32)[0]

        image_blur = tf.image.rot90(image_blur, rot)
        image_sharp = tf.image.rot90(image_sharp, rot)

        rl = tf.equal(flip_rl, 0)
        ud = tf.equal(flip_updown, 0)

        image_blur = tf.cond(rl, true_fn=lambda: tf.image.flip_left_right(image_blur),
                             false_fn=lambda: image_blur)
        image_sharp = tf.cond(rl, true_fn=lambda: tf.image.flip_left_right(image_sharp),
                              false_fn=lambda: image_sharp)

        image_blur = tf.cond(ud, true_fn=lambda: tf.image.flip_up_down(image_blur),
                             false_fn=lambda: image_blur)
        image_sharp = tf.cond(ud, true_fn=lambda: tf.image.flip_up_down(image_sharp),
                              false_fn=lambda: image_sharp)

        return image_blur, image_sharp

main.py

import argparse
import os

import tensorflow as tf

from DeblurGAN import DeblurGAN
from mode import *

parser = argparse.ArgumentParser()


def str2bool(v):
    """Parse a command-line boolean flag.

    Accepts 'true', 't', 'yes', 'y' and '1' (case-insensitive) as True, and
    anything else as False. Bool values, such as argparse defaults, pass
    through unchanged.
    """
    if isinstance(v, bool):
        return v
    # The tuple is essential: the old ``in ('true')`` was a substring test
    # against the string 'true', so '', 'e' or 'rue' also parsed as True.
    return v.lower() in ('true', 't', 'yes', 'y', '1')


## Model specification
parser.add_argument("--n_feats", type=int, default=64)
parser.add_argument("--num_of_down_scale", type=int, default=2)
parser.add_argument("--gen_resblocks", type=int, default=9)
parser.add_argument("--discrim_blocks", type=int, default=3)

## Data specification
parser.add_argument("--train_Sharp_path", type=str, default="data/train/sharp/")
parser.add_argument("--train_Blur_path", type=str, default="data/train/blur/")
parser.add_argument("--test_Sharp_path", type=str, default="data/test/sharp/")
parser.add_argument("--test_Blur_path", type=str, default="data/test/blur/")
# VGG19 weights (.npy) for the perceptual loss. The default points at the
# current user's Keras model cache rather than one developer's hard-coded path.
parser.add_argument("--vgg_path", type=str,
                    default=os.path.join(os.path.expanduser("~"), ".keras", "models", "vgg19.npy"))
parser.add_argument("--patch_size", type=int, default=256)  # training crop size
parser.add_argument("--result_path", type=str, default="./result")
parser.add_argument("--model_path", type=str, default="./model")

## Optimization
parser.add_argument("--batch_size", type=int, default=2)  # original note: "1 training-For-4"
parser.add_argument("--max_epoch", type=int, default=2)
parser.add_argument("--learning_rate", type=float, default=1e-4)
parser.add_argument("--decay_step", type=int, default=150)
parser.add_argument("--test_with_train", type=str2bool, default=False)
parser.add_argument("--save_test_result", type=str2bool, default=True)

## Training or test specification
# Restricting the choices makes a typo fail loudly instead of silently doing nothing.
parser.add_argument("--mode", type=str, default="train", choices=["train", "test"])
parser.add_argument("--critic_updates", type=int, default=5)
parser.add_argument("--augmentation", type=str2bool, default=False)
parser.add_argument("--load_X", type=int, default=640)
parser.add_argument("--load_Y", type=int, default=360)
parser.add_argument("--fine_tuning", type=str2bool, default=False)
parser.add_argument("--log_freq", type=int, default=1)
parser.add_argument("--model_save_freq", type=int, default=1)
parser.add_argument("--pre_trained_model", type=str, default="./model/")
parser.add_argument("--test_batch", type=int, default=2)
args = parser.parse_args()

model = DeblurGAN(args)
model.build_graph()

print("Build DeblurGAN model!")

config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver(max_to_keep=None)

if args.mode == 'train':
    train(args, model, sess, saver)

elif args.mode == 'test':
    # The with-block guarantees the results file is closed even if test() raises.
    with open("test_results.txt", 'w') as f:
        test(args, model, sess, saver, f, step=-1, loading=True)


mode.py

import os
import tensorflow as tf
from PIL import Image
import numpy as np
import time
import util2


def train(args, model, sess, saver):
    """Train ``model`` for ``args.max_epoch`` epochs.

    All training images are loaded into memory once (``util2.image_loader``).
    Each step crops a random batch of aligned patches (``util2.batch_gen``),
    runs ``args.critic_updates`` discriminator updates, then one generator
    update. Every ``args.log_freq`` epochs a summary is logged (and validation
    runs if ``args.test_with_train`` is set). Every ``args.model_save_freq``
    epochs a checkpoint is written under ``args.model_path``.

    :param args: parsed command-line arguments (see main.py).
    :param model: DeblurGAN instance whose graph has been built.
    :param sess: tf.Session with variables initialised.
    :param saver: tf.train.Saver used for the fine-tuning restore and checkpoints.
    :raises ValueError: if the blur and sharp folders hold different numbers
        of images, or fewer images than ``args.batch_size`` (no step could run).
    """
    if args.fine_tuning:
        saver.restore(sess, args.pre_trained_model)
        print("saved model is loaded for fine-tuning!")
        print("model path is %s" % args.pre_trained_model)

    merged = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter('./logs', sess.graph)

    blur_imgs = util2.image_loader(args.train_Blur_path, args.load_X, args.load_Y)
    sharp_imgs = util2.image_loader(args.train_Sharp_path, args.load_X, args.load_Y)

    # Pairs are matched by sorted position, so the two folders must line up.
    if len(blur_imgs) != len(sharp_imgs):
        raise ValueError("found %d blur images but %d sharp images"
                         % (len(blur_imgs), len(sharp_imgs)))
    num_imgs = len(sharp_imgs)
    step = num_imgs // args.batch_size
    if step == 0:
        # Without this guard the logging below would hit unassigned losses.
        raise ValueError("need at least batch_size=%d training images, found %d"
                         % (args.batch_size, num_imgs))

    # tf.train.Saver refuses to write into a directory that does not exist.
    os.makedirs(args.model_path, exist_ok=True)
    ckpt_prefix = os.path.join(args.model_path, 'DeblurrGAN')

    f = open("valid_logs.txt", 'w') if args.test_with_train else None
    try:
        for epoch in range(args.max_epoch):
            random_index = np.random.permutation(len(blur_imgs))
            for k in range(step):
                s_time = time.time()
                blur_batch, sharp_batch = util2.batch_gen(blur_imgs, sharp_imgs, args.patch_size,
                                                          args.batch_size, random_index, k)

                for _ in range(args.critic_updates):
                    _, D_loss = sess.run([model.D_train, model.D_loss],
                                         feed_dict={model.blur: blur_batch, model.sharp: sharp_batch,
                                                    model.epoch: epoch})

                _, G_loss = sess.run([model.G_train, model.G_loss],
                                     feed_dict={model.blur: blur_batch, model.sharp: sharp_batch,
                                                model.epoch: epoch})

                e_time = time.time()

            if epoch % args.log_freq == 0:
                # The summary is computed on the last batch of the epoch.
                summary = sess.run(merged, feed_dict={model.blur: blur_batch, model.sharp: sharp_batch})
                train_writer.add_summary(summary, epoch)
                if args.test_with_train:
                    test(args, model, sess, saver, f, epoch, loading=False)
                print("%d training epoch completed" % epoch)
                print("D_loss : {}, \t G_loss : {}".format(D_loss, G_loss))
                # Time of the last step only, not of the whole epoch.
                print("Elpased time : %0.4f" % (e_time - s_time))
            if epoch % args.model_save_freq == 0:
                saver.save(sess, ckpt_prefix, global_step=epoch, write_meta_graph=True)

        saver.save(sess, os.path.join(args.model_path, 'DeblurrGAN_last'), write_meta_graph=True)
    finally:
        if f is not None:
            f.close()


def test(args, model, sess, saver, file, step=-1, loading=False):
    """Evaluate ``model`` on the test set and write mean PSNR/SSIM to ``file``.

    Test images are fed at full resolution (``is_train=False`` skips resizing),
    one image per ``sess.run``. When ``args.save_test_result`` is set, each
    output is saved to ``args.result_path`` as ``<stem>_deblur.png``.

    :param file: open text file for the summary line. The caller owns it and
        is responsible for closing it.
    :param step: epoch number for periodic validation, or -1 for a one-off
        test run. This only changes the format of the logged line.
    :param loading: restore the latest checkpoint from
        ``args.pre_trained_model`` before evaluating.
    :raises ValueError: if the test folders are empty or hold different
        numbers of images.
    """
    if loading:
        print(" [*] Reading checkpoints...")
        ckpt = tf.train.get_checkpoint_state(args.pre_trained_model)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            saver.restore(sess, os.path.join(args.pre_trained_model, ckpt_name))
            print(" [*] Success to read {}".format(ckpt_name))
        else:
            print(" [*] Failed to find a checkpoint")

    blur_img_name = sorted(os.listdir(args.test_Blur_path))

    blur_imgs = util2.image_loader(args.test_Blur_path, args.load_X, args.load_Y, is_train=False)
    sharp_imgs = util2.image_loader(args.test_Sharp_path, args.load_X, args.load_Y, is_train=False)

    if len(blur_imgs) != len(sharp_imgs):
        raise ValueError("found %d blur test images but %d sharp test images"
                         % (len(blur_imgs), len(sharp_imgs)))
    if not blur_imgs:
        raise ValueError("no test images found in %s" % args.test_Blur_path)

    if args.save_test_result:
        # Create the directory that is actually written to (previously a fixed './result/').
        os.makedirs(args.result_path, exist_ok=True)

    PSNR_list = []
    ssim_list = []

    for name, blur_img, sharp_img in zip(blur_img_name, blur_imgs, sharp_imgs):
        blur = np.expand_dims(blur_img, axis=0)
        sharp = np.expand_dims(sharp_img, axis=0)
        output, psnr, ssim = sess.run([model.output, model.PSNR, model.ssim],
                                      feed_dict={model.blur: blur, model.sharp: sharp})

        if args.save_test_result:
            # model.output is already de-normalised to uint8 in [0, 255].
            stem = os.path.splitext(name)[0]
            Image.fromarray(output[0]).save(os.path.join(args.result_path, '%s_deblur.png' % stem))

        PSNR_list.append(psnr)
        ssim_list.append(ssim)

    length = len(PSNR_list)

    mean_PSNR = sum(PSNR_list) / length
    mean_ssim = sum(ssim_list) / length

    if step == -1:
        file.write('PSNR : {} SSIM : {}'.format(mean_PSNR, mean_ssim))
    else:
        file.write("{}d-epoch step PSNR : {} SSIM : {} \n".format(step, mean_PSNR, mean_ssim))

toPB.py

import tensorflow as tf
import os
from tensorflow.python.framework import graph_util
from tensorflow.python import pywrap_tensorflow



def freeze_graph(input_checkpoint, output_graph, output_node_names="generator/clip"):
    '''
    Freeze a checkpoint into a self-contained GraphDef (.pb) file, with the
    variable values baked in as constants.

    :param input_checkpoint: checkpoint prefix (e.g. './model/DeblurrGAN-1'),
        or a directory containing a TensorFlow 'checkpoint' index file, in
        which case the latest checkpoint in it is used.
    :param output_graph: path where the .pb model is saved.
    :param output_node_names: comma-separated output node names to keep. They
        must exist in the original graph.
    :return: None
    :raises ValueError: if a directory is given that contains no checkpoint.
    '''
    if tf.gfile.IsDirectory(input_checkpoint):
        # A bare directory has no '<prefix>.meta' file; resolve it to a prefix.
        checkpoint_dir = input_checkpoint
        input_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
        if input_checkpoint is None:
            raise ValueError("No checkpoint found in directory: %s" % checkpoint_dir)

    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)

    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)  # restore the variable values
        output_graph_def = graph_util.convert_variables_to_constants(  # fold variables into constants
            sess=sess,
            input_graph_def=sess.graph_def,
            output_node_names=output_node_names.split(","))  # multiple output nodes are comma-separated

        with tf.gfile.GFile(output_graph, "wb") as f:  # write the model
            f.write(output_graph_def.SerializeToString())  # serialised GraphDef
        print("%d ops in the final graph." % len(output_graph_def.node))  # number of nodes in the frozen graph

if __name__ == '__main__':
    # Checkpoint prefix (e.g. './model/DeblurrGAN-1') or a directory that
    # contains a TensorFlow 'checkpoint' index file.
    input_checkpoint = './'
    if os.path.isdir(input_checkpoint):
        # A directory cannot be used as a prefix ('./' + '.meta' is './.meta');
        # resolve it to the latest checkpoint recorded there.
        latest = tf.train.latest_checkpoint(input_checkpoint)
        if latest is None:
            raise SystemExit("No checkpoint found in %s" % input_checkpoint)
        input_checkpoint = latest
    # # To list the variable names stored in the checkpoint:
    # reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
    # var_to_shape_map = reader.get_variable_to_shape_map()
    # for key in var_to_shape_map:
    #     print("tensor_name: ", key)
    # Output path of the frozen .pb model
    out_pb_path = "./frozen_model.pb"
    # Convert the checkpoint to .pb
    freeze_graph(input_checkpoint, out_pb_path)

test_pb.py

import tensorflow as tf
import os
import cv2
from tensorflow.python.framework import graph_util
import numpy as np
from PIL import Image

def get_RAC(file_path, save_path, model_path):
    """Deblur one image with a frozen DeblurGAN graph and save the result.

    The graph is looked up by tensor name: input "blur:0" (pixel values in
    [0, 255]) and output "generator/clip:0" (values in [-1, 1]).
    NOTE(review): with the default two down-scaling stages, the image height
    and width presumably need to be multiples of 4 for the generator's input
    skip-add; confirm.

    :param file_path: path of the blurred input image.
    :param save_path: path where the deblurred image is written.
    :param model_path: path of the frozen .pb file (see toPB.py).
    """
    with tf.Graph().as_default():
        output_graph_def = tf.GraphDef()

        with open(model_path, "rb") as f:
            output_graph_def.ParseFromString(f.read())
        tf.import_graph_def(output_graph_def, name="")

        # A frozen graph holds constants only, so no variable initialisation is needed.
        with tf.Session() as sess:
            # Load the image (the with-block closes the underlying file).
            with Image.open(file_path) as src:
                img = np.array(src.convert('RGB'))
            img = np.expand_dims(img, axis=0)

            input_x = sess.graph.get_tensor_by_name("blur:0")
            output = sess.graph.get_tensor_by_name("generator/clip:0")
            print("output:", output)

            output_image = sess.run(output, feed_dict={input_x: img})[0]
            print("output_img:", output_image.shape)
            # Map [-1, 1] back to [0, 255]. Round rather than truncate, matching
            # the tf.round used for DeblurGAN.output, and clip against float
            # overshoot before the uint8 cast.
            output_image = (output_image + 1.0) * 255.0 / 2.0
            output_image = np.clip(np.round(output_image), 0, 255).astype(np.uint8)
            Image.fromarray(output_image).save(save_path)
            



# Path of the frozen model
model_path = 'model/0000/test.pb'
# Directory of blurred test images (processed one by one)
path1 = 'data/test/blur/'
# Directory where the generated images are saved
path2 = 'data/result/'

# Image.save fails if the output directory does not exist.
os.makedirs(path2, exist_ok=True)

# Deblur each image in path1 and save it under the same file name in path2.
for file_name in sorted(os.listdir(path1)):
    file_path = os.path.join(path1, file_name)
    if not os.path.isfile(file_path):
        continue  # skip sub-directories
    save_path = os.path.join(path2, file_name)
    get_RAC(file_path, save_path, model_path)

util.py

from PIL import Image
import numpy as np
import random
import os


def image_loader(image_path, load_x, load_y, is_train=True):
    """Load every image in ``image_path`` (sorted by name) as an RGB numpy array.

    :param image_path: directory whose entries are all image files.
    :param load_x: target width used when ``is_train`` is true.
    :param load_y: target height used when ``is_train`` is true.
    :param is_train: when true, bicubically resize each image to
        (load_x, load_y). Otherwise keep the original size.
    :return: list of HxWx3 uint8 arrays in sorted file-name order.
    """
    img_list = []
    for ele in sorted(os.listdir(image_path)):
        # Converting to RGB guarantees 3 channels (grayscale/RGBA/palette
        # inputs would otherwise break batch_gen's HxWxC unpacking). The
        # with-block closes the file handle Image.open keeps open.
        with Image.open(os.path.join(image_path, ele)) as src:
            img = src.convert('RGB')
        if is_train:
            img = img.resize((load_x, load_y), Image.BICUBIC)
        img_list.append(np.array(img))

    return img_list


def data_augument(lr_img, hr_img, aug):
    """Apply the same geometric transform, selected by ``aug``, to both images.

    aug < 4: rotate by aug * 90 degrees counter-clockwise (np.rot90).
    aug == 4: mirror left-right.  aug == 5: mirror up-down.
    aug == 6: left-right mirror, then rotate 90.
    aug == 7: up-down mirror, then rotate 90.
    Any other value returns both inputs untouched.
    """
    def _transform(img):
        if aug < 4:
            return np.rot90(img, aug)
        if aug == 4:
            return np.fliplr(img)
        if aug == 5:
            return np.flipud(img)
        if aug == 6:
            return np.rot90(np.fliplr(img))
        if aug == 7:
            return np.rot90(np.flipud(img))
        return img

    return _transform(lr_img), _transform(hr_img)


def batch_gen(blur_imgs, sharp_imgs, patch_size, batch_size, random_index, step, augment=False):
    """Crop one batch of aligned random patches from paired blur/sharp images.

    :param blur_imgs: list of blurred images as HxW[xC] numpy arrays.
    :param sharp_imgs: list of sharp images, index-aligned with ``blur_imgs``.
    :param patch_size: side length of the square crop.
    :param batch_size: number of images per batch.
    :param random_index: permutation of image indices for this epoch.
    :param step: batch number; the batch uses
        ``random_index[step * batch_size:(step + 1) * batch_size]``.
    :param augment: when true, apply a random ``data_augument`` transform to
        each patch pair.
    :return: ``(blur_batch, sharp_batch)`` numpy arrays stacked on axis 0.
    :raises ValueError: if a pair differs in spatial size or is smaller than
        ``patch_size``.
    """
    img_index = random_index[step * batch_size: (step + 1) * batch_size]

    blur_batch = []
    sharp_batch = []

    for idx in img_index:
        blur_img = blur_imgs[idx]
        sharp_img = sharp_imgs[idx]

        ih, iw = blur_img.shape[:2]
        # Same offsets are used for both images, so they must be the same size.
        if sharp_img.shape[:2] != (ih, iw):
            raise ValueError("blur/sharp size mismatch at index %s: %s vs %s"
                             % (idx, blur_img.shape, sharp_img.shape))
        if ih < patch_size or iw < patch_size:
            raise ValueError("image at index %s is %dx%d, smaller than patch_size=%d"
                             % (idx, iw, ih, patch_size))

        ix = random.randrange(0, iw - patch_size + 1)
        iy = random.randrange(0, ih - patch_size + 1)

        img_blur_in = blur_img[iy:iy + patch_size, ix:ix + patch_size]
        img_sharp_in = sharp_img[iy:iy + patch_size, ix:ix + patch_size]

        if augment:
            aug = random.randrange(0, 8)
            img_blur_in, img_sharp_in = data_augument(img_blur_in, img_sharp_in, aug)

        blur_batch.append(img_blur_in)
        sharp_batch.append(img_sharp_in)

    return np.array(blur_batch), np.array(sharp_batch)

util2.py

from PIL import Image
import numpy as np
import random
import os

fine_size = 512
miu = np.array([0.5, 0.5, 0.5])
std = np.array([0.5, 0.5, 0.5])
def preprocess(img):
    # #归一化
    # img_np = np.array(img, dtype=float) / 255.
    # r = (img_np[:, :, 0] - miu[0]) / std[0]
    # g = (img_np[:, :, 1] - miu[1]) / std[1]
    # b = (img_np[:, :, 2] - miu[2]) / std[2]
    # img_np_t = np.array([r, g, b])
    # print(img_np_t.shape)
    # 不进行归一化
    img_np_t = img
    # #随机裁剪
    W = img_np_t.shape[1]
    H = img_np_t.shape[2]
    Ws = np.random.randint(0, W - fine_size - 1, 1)[0]
    Hs = np.random.randint(0, H - fine_size - 1, 1)[0]
    img_np_t = img_np_t[:, Ws:Ws + fine_size, Hs:Hs + fine_size]
    # # 随即裁剪
    # ih, iw, _ = img_np_t.shape
    # ix = random.randrange(0, iw - fine_size + 1)
    # iy = random.randrange(0, ih - fine_size + 1)
    # #
    # img_np_t = img_np_t[Ws:Ws + fine_size, Hs:Hs + fine_size]

    return img_np_t
def test_preprocess(img):
  #归一化
    img_np = np.array(img, dtype=float) / 255.
    r = (img_np[:, :, 0] - miu[0]) / std[0]
    g = (img_np[:, :, 1] - miu[1]) / std[1]
    b = (img_np[:, :, 2] - miu[2]) / std[2]
    img_np_t = np.array([r, g, b])
    return img_np_t

def image_loader(image_path, load_x, load_y, is_train=True):
    """Load every image in ``image_path`` (sorted by name) as an RGB numpy array.

    :param image_path: directory whose entries are all image files.
    :param load_x: target width used when ``is_train`` is true.
    :param load_y: target height used when ``is_train`` is true.
    :param is_train: when true, bicubically resize each image to
        (load_x, load_y). Otherwise keep the original size. No normalisation
        is applied; values stay uint8 in [0, 255].
    :return: list of HxWx3 uint8 arrays in sorted file-name order.
    """
    img_list = []
    for ele in sorted(os.listdir(image_path)):
        # The with-block closes the file handle that Image.open keeps open;
        # convert() returns a fully loaded copy.
        with Image.open(os.path.join(image_path, ele)) as src:
            img = src.convert('RGB')
        if is_train:
            img = img.resize((load_x, load_y), Image.BICUBIC)
        img_list.append(np.array(img))

    return img_list


def data_augument(lr_img, hr_img, aug):
    """Transform an image pair identically according to the code ``aug``.

    Codes below 4 rotate by ``aug`` quarter turns (np.rot90). Codes 4-7
    apply: 4 left-right mirror, 5 up-down mirror, 6 left-right mirror then
    rot90, 7 up-down mirror then rot90. Unknown codes leave both images as is.
    """
    if aug < 4:
        transform = lambda im: np.rot90(im, aug)
    else:
        transform = {
            4: np.fliplr,
            5: np.flipud,
            6: lambda im: np.rot90(np.fliplr(im)),
            7: lambda im: np.rot90(np.flipud(im)),
        }.get(aug, lambda im: im)

    out_lr = transform(lr_img)
    out_hr = transform(hr_img)
    return out_lr, out_hr

# Random cropping is performed (again) here.
# NOTE(review): the original author asked how this should change at test
# time; mode.test() feeds full-resolution images directly and does not call
# batch_gen, so only the training path goes through this crop.
def batch_gen(blur_imgs, sharp_imgs, patch_size, batch_size, random_index, step, augment=False):
    """Crop one batch of aligned random patches from paired blur/sharp images.

    :param blur_imgs: list of blurred images as HxW[xC] numpy arrays.
    :param sharp_imgs: list of sharp images, index-aligned with ``blur_imgs``.
    :param patch_size: side length of the square crop.
    :param batch_size: number of images per batch.
    :param random_index: permutation of image indices for this epoch.
    :param step: batch number; the batch uses
        ``random_index[step * batch_size:(step + 1) * batch_size]``.
    :param augment: when true, apply a random ``data_augument`` transform to
        each patch pair. This flag was previously accepted but ignored.
    :return: ``(blur_batch, sharp_batch)`` numpy arrays stacked on axis 0.
    :raises ValueError: if a pair differs in spatial size or is smaller than
        ``patch_size``.
    """
    img_index = random_index[step * batch_size: (step + 1) * batch_size]

    blur_batch = []
    sharp_batch = []

    for idx in img_index:
        blur_img = blur_imgs[idx]
        sharp_img = sharp_imgs[idx]

        ih, iw = blur_img.shape[:2]
        # The same offsets are used for both images, so they must match in size.
        if sharp_img.shape[:2] != (ih, iw):
            raise ValueError("blur/sharp size mismatch at index %s: %s vs %s"
                             % (idx, blur_img.shape, sharp_img.shape))
        if ih < patch_size or iw < patch_size:
            raise ValueError("image at index %s is %dx%d, smaller than patch_size=%d"
                             % (idx, iw, ih, patch_size))

        ix = random.randrange(0, iw - patch_size + 1)
        iy = random.randrange(0, ih - patch_size + 1)

        img_blur_in = blur_img[iy:iy + patch_size, ix:ix + patch_size]
        img_sharp_in = sharp_img[iy:iy + patch_size, ix:ix + patch_size]

        if augment:
            aug = random.randrange(0, 8)
            img_blur_in, img_sharp_in = data_augument(img_blur_in, img_sharp_in, aug)

        blur_batch.append(img_blur_in)
        sharp_batch.append(img_sharp_in)

    # Batches stay NHWC; no transpose is needed for the TF graph.
    return np.array(blur_batch), np.array(sharp_batch)

DeblurGAN.py

from layer import *
from data_loader import dataloader
from vgg19 import Vgg19



class DeblurGAN():
    """DeblurGAN v1: ResNet-style generator plus a fully-convolutional critic.

    Training uses a WGAN-GP adversarial loss plus a VGG19 relu3_3 content
    loss. Images are fed through the ``blur`` / ``sharp`` placeholders with
    pixel values in [0, 255]; ``build_graph`` maps them to [-1, 1].
    """

    def __init__(self, args):
        # Only the loader configuration is stored here: build_loader() is not
        # called, so no image data is actually read at this point.
        self.data_loader = dataloader(args)
        print("data has been loaded")

        self.channel = 3

        self.n_feats = args.n_feats
        self.mode = args.mode
        self.batch_size = args.batch_size
        self.num_of_down_scale = args.num_of_down_scale
        self.gen_resblocks = args.gen_resblocks
        self.discrim_blocks = args.discrim_blocks
        self.vgg_path = args.vgg_path

        self.learning_rate = args.learning_rate
        self.decay_step = args.decay_step

    def down_scaling_feature(self, name, x, n_feats):
        """Stride-2 3x3 conv doubling the channels, then instance norm and ReLU."""
        x = Conv(name=name + 'conv', x=x, filter_size=3, in_filters=n_feats,
                 out_filters=n_feats * 2, strides=2, padding='SAME')
        x = instance_norm(x)
        x = tf.nn.relu(x)

        return x

    def up_scaling_feature(self, name, x, n_feats):
        """2x transposed conv halving the channels, then instance norm and ReLU."""
        x = Conv_transpose(name=name + 'deconv', x=x, filter_size=3, in_filters=n_feats,
                           out_filters=n_feats // 2, fraction=2, padding='SAME')
        x = instance_norm(x)
        x = tf.nn.relu(x)

        return x

    def res_block(self, name, x, n_feats):
        """Two reflect-padded 3x3 convs with instance norm, plus an identity skip."""
        _res = x

        x = tf.pad(x, [[0, 0], [1, 1], [1, 1], [0, 0]], mode='REFLECT')
        x = Conv(name=name + 'conv1', x=x, filter_size=3, in_filters=n_feats,
                 out_filters=n_feats, strides=1, padding='VALID')
        x = instance_norm(x)
        x = tf.nn.relu(x)

        x = tf.pad(x, [[0, 0], [1, 1], [1, 1], [0, 0]], mode='REFLECT')
        x = Conv(name=name + 'conv2', x=x, filter_size=3, in_filters=n_feats,
                 out_filters=n_feats, strides=1, padding='VALID')
        x = instance_norm(x)

        x = x + _res

        return x

    def generator(self, x, reuse=False, name='generator'):
        """Map a [-1, 1] blurred batch to a deblurred batch clipped to [-1, 1].

        The network predicts a residual (tanh output) that is added to the
        input. The final clip op is named 'generator/clip', which toPB.py and
        test_pb.py look up.
        """
        with tf.variable_scope(name_or_scope=name, reuse=reuse):
            _res = x
            x = tf.pad(x, [[0, 0], [3, 3], [3, 3], [0, 0]], mode='REFLECT')
            x = Conv(name='conv1', x=x, filter_size=7, in_filters=self.channel,
                     out_filters=self.n_feats, strides=1, padding='VALID')

            x = instance_norm(x)
            x = tf.nn.relu(x)

            for i in range(self.num_of_down_scale):
                x = self.down_scaling_feature(name='down_%02d' % i, x=x, n_feats=self.n_feats * (i + 1))

            for i in range(self.gen_resblocks):
                x = self.res_block(name='res_%02d' % i, x=x, n_feats=self.n_feats * (2 ** self.num_of_down_scale))

            for i in range(self.num_of_down_scale):
                x = self.up_scaling_feature(name='up_%02d' % i, x=x,
                                            n_feats=self.n_feats * (2 ** (self.num_of_down_scale - i)))

            x = tf.pad(x, [[0, 0], [3, 3], [3, 3], [0, 0]], mode='REFLECT')
            x = Conv(name='conv_last', x=x, filter_size=7, in_filters=self.n_feats,
                     out_filters=self.channel, strides=1, padding='VALID')
            x = tf.nn.tanh(x)
            x = x + _res
            x = tf.clip_by_value(x, -1.0, 1.0, name="clip")
            return x

    def discriminator(self, x, reuse=False, name='discriminator'):
        """Fully-convolutional critic returning a per-location score map.

        NOTE(review): the output goes through a sigmoid although the losses in
        build_graph are WGAN-style (raw means of the scores). A WGAN critic is
        normally unbounded; kept as is to stay checkpoint-compatible.
        """
        with tf.variable_scope(name_or_scope=name, reuse=reuse):
            x = Conv(name='conv1', x=x, filter_size=4, in_filters=self.channel,
                     out_filters=self.n_feats, strides=2, padding="SAME")
            x = instance_norm(x)
            x = tf.nn.leaky_relu(x)

            n = 1

            # Channel multiplier doubles per block, capped at 8.
            for i in range(self.discrim_blocks):
                prev = n
                n = min(2 ** (i + 1), 8)
                x = Conv(name='conv%02d' % i, x=x, filter_size=4, in_filters=self.n_feats * prev,
                         out_filters=self.n_feats * n, strides=2, padding="SAME")
                x = instance_norm(x)
                x = tf.nn.leaky_relu(x)

            prev = n
            n = min(2 ** self.discrim_blocks, 8)
            x = Conv(name='conv_d1', x=x, filter_size=4, in_filters=self.n_feats * prev,
                     out_filters=self.n_feats * n, strides=1, padding="SAME")
            x = instance_norm(x)
            x = tf.nn.leaky_relu(x)

            x = Conv(name='conv_d2', x=x, filter_size=4, in_filters=self.n_feats * n,
                     out_filters=1, strides=1, padding="SAME")
            x = tf.nn.sigmoid(x)

            return x

    def build_graph(self):
        """Create placeholders, losses, optimisers, metrics and summaries.

        Exposes ``blur``, ``sharp`` and ``epoch`` placeholders, the
        ``D_train``/``G_train`` ops, ``D_loss``/``G_loss``, ``PSNR``/``ssim``,
        and ``output`` (uint8 in [0, 255]).
        """
        self.blur = tf.placeholder(name="blur", shape=[None, None, None, self.channel], dtype=tf.float32)
        self.sharp = tf.placeholder(name="sharp", shape=[None, None, None, self.channel], dtype=tf.float32)

        x = self.blur
        label = self.sharp

        self.epoch = tf.placeholder(name='train_step', shape=None, dtype=tf.int32)

        # [0, 255] -> [-1, 1]
        x = (2.0 * x / 255.0) - 1.0
        label = (2.0 * label / 255.0) - 1.0

        self.gene_img = self.generator(x, reuse=False)
        self.real_prob = self.discriminator(label, reuse=False)
        self.fake_prob = self.discriminator(self.gene_img, reuse=True)

        # Use the runtime batch size rather than args.batch_size, so the graph
        # stays valid when a differently sized batch is fed.
        batch = tf.shape(label)[0]

        # Gradient penalty on random interpolates between real and generated images.
        epsilon = tf.random_uniform(shape=tf.stack([batch, 1, 1, 1]), minval=0.0, maxval=1.0)

        interpolated_input = epsilon * label + (1 - epsilon) * self.gene_img
        gradient = tf.gradients(self.discriminator(interpolated_input, reuse=True), [interpolated_input])[0]
        # NOTE(review): this takes the root of the *mean* squared gradient (an
        # RMS), not the L2 norm used by WGAN-GP (reduce_sum). Left unchanged
        # because it alters training dynamics.
        GP_loss = tf.reduce_mean(tf.square(tf.sqrt(tf.reduce_mean(tf.square(gradient), axis=[1, 2, 3])) - 1))

        d_loss_real = - tf.reduce_mean(self.real_prob)
        d_loss_fake = tf.reduce_mean(self.fake_prob)

        # Content loss: VGG19 relu3_3 features of the sharp target (first
        # half of the concatenated batch) vs. the generated image (second half).
        self.vgg_net = Vgg19(self.vgg_path)
        self.vgg_net.build(tf.concat([label, self.gene_img], axis=0))
        self.content_loss = tf.reduce_mean(tf.reduce_sum(tf.square(
            self.vgg_net.relu3_3[batch:] - self.vgg_net.relu3_3[:batch]), axis=3))

        self.D_loss = d_loss_real + d_loss_fake + 10.0 * GP_loss
        self.G_loss = - d_loss_fake + 100.0 * self.content_loss

        t_vars = tf.trainable_variables()
        G_vars = [var for var in t_vars if 'generator' in var.name]
        D_vars = [var for var in t_vars if 'discriminator' in var.name]

        # The learning rate is constant for the first decay_step epochs, then
        # decays linearly to 0 at epoch 2 * decay_step and stays at 0. The
        # previous tf.abs() made it climb back up after 2 * decay_step.
        decayed = 2 * self.learning_rate - (
                self.learning_rate * tf.cast(self.epoch, tf.float32) / self.decay_step)
        lr = tf.minimum(self.learning_rate, tf.maximum(0.0, decayed))
        self.D_train = tf.train.AdamOptimizer(learning_rate=lr).minimize(self.D_loss, var_list=D_vars)
        self.G_train = tf.train.AdamOptimizer(learning_rate=lr).minimize(self.G_loss, var_list=G_vars)

        # Metrics are computed in [0, 1].
        self.PSNR = tf.reduce_mean(tf.image.psnr(((self.gene_img + 1.0) / 2.0), ((label + 1.0) / 2.0), max_val=1.0))
        self.ssim = tf.reduce_mean(tf.image.ssim(((self.gene_img + 1.0) / 2.0), ((label + 1.0) / 2.0), max_val=1.0))

        logging_D_loss = tf.summary.scalar(name='D_loss', tensor=self.D_loss)
        logging_G_loss = tf.summary.scalar(name='G_loss', tensor=self.G_loss)
        logging_PSNR = tf.summary.scalar(name='PSNR', tensor=self.PSNR)
        logging_ssim = tf.summary.scalar(name='ssim', tensor=self.ssim)

        # [-1, 1] -> rounded uint8 in [0, 255] for saving.
        self.output = (self.gene_img + 1.0) * 255.0 / 2.0
        self.output = tf.round(self.output)
        self.output = tf.cast(self.output, tf.uint8)

layer.py

import tensorflow as tf
import numpy as np


def Conv(name, x, filter_size, in_filters, out_filters, strides, padding):
    """2-D convolution plus bias, with its variables created under scope ``name``.

    The kernel ('filter') is drawn from N(0, 0.01^2) and the bias starts at
    zero. Strides are passed as [1, s, s, 1], i.e. for conv2d's default NHWC
    data format.
    """
    kernel_shape = [filter_size, filter_size, in_filters, out_filters]
    stride_vec = [1, strides, strides, 1]
    with tf.variable_scope(name):
        weights = tf.get_variable('filter', kernel_shape, tf.float32,
                                  initializer=tf.random_normal_initializer(stddev=0.01))
        offset = tf.get_variable('bias', [out_filters], tf.float32, initializer=tf.zeros_initializer())
        conv_out = tf.nn.conv2d(x, weights, stride_vec, padding=padding)
        return conv_out + offset


def Conv_transpose(name, x, filter_size, in_filters, out_filters, fraction=2, padding="SAME"):
    """Transposed convolution that scales height and width by ``fraction``.

    No bias is added. The kernel ('filter', shaped [k, k, out, in] as
    conv2d_transpose expects) is drawn from N(0, 2 / (k * k * out_filters)).
    The output shape is computed from the runtime shape of ``x``.
    """
    fan = filter_size * filter_size * out_filters
    init = tf.random_normal_initializer(stddev=np.sqrt(2.0 / fan))
    with tf.variable_scope(name):
        weights = tf.get_variable('filter', [filter_size, filter_size, out_filters, in_filters], tf.float32,
                                  initializer=init)
        in_shape = tf.shape(x)
        out_shape = tf.stack([in_shape[0], in_shape[1] * fraction, in_shape[2] * fraction, out_filters])
        return tf.nn.conv2d_transpose(x, weights, out_shape, [1, fraction, fraction, 1], padding)


# A hand-written tf.nn.moments version used to live here. It normalised over
# every axis including the batch axis, which is not instance normalisation,
# so it was dropped in favour of the contrib layer below.
def instance_norm(x):
    """Instance-normalise ``x`` with ``tf.contrib.layers.instance_norm`` (TF 1.x, e.g. 1.8).

    The layer is called with its default arguments.
    """
    return tf.contrib.layers.instance_norm(x)

vgg19.py

import tensorflow as tf
import numpy as np
import time

# Per-channel means subtracted in Vgg19.build(); the order is B, G, R.
VGG_MEAN = [
    103.939,  # blue
    116.779,  # green
    123.68,   # red
]


class Vgg19:
    """
    VGG-19 feature extractor whose weights are frozen graph constants.

    Weights are read from a ``.npy`` file and embedded with ``tf.constant``,
    so they are not trainable. After :meth:`build`, every intermediate tensor
    is exposed as an attribute: ``conv{b}_{i}`` (pre-activation),
    ``relu{b}_{i}`` and ``pool{b}`` for blocks ``b`` = 1..5.
    """

    # Conv/ReLU pairs per block: 2 + 2 + 4 + 4 + 4 = 16 conv layers (VGG-19).
    _BLOCK_DEPTHS = (2, 2, 4, 4, 4)

    def __init__(self, vgg19_npy_path):
        """
        Load the pretrained weights into ``self.data_dict``.

        :param vgg19_npy_path: path to a ``.npy`` file containing a pickled
            mapping from layer name (e.g. ``"conv1_1"``) to a sequence whose
            item 0 is the conv filter and item 1 is the bias.
        """
        # The file holds a pickled object; NumPy >= 1.16.3 refuses to
        # unpickle unless allow_pickle=True is given explicitly.
        # Unpickling can execute arbitrary code, so only load trusted files.
        self.data_dict = np.load(vgg19_npy_path, encoding='latin1',
                                 allow_pickle=True).item()
        print("npy file loaded")

    def build(self, rgb):
        """
        Build the VGG-19 graph from the loaded weights.

        The weight dictionary is released at the end to free memory, so this
        can only be called once per instance.

        :param rgb: rgb image [batch, height, width, 3] values scaled [-1, 1]
        :raises RuntimeError: if called again after a previous build.
        """
        if self.data_dict is None:
            raise RuntimeError(
                "Vgg19.build() can only be called once: the weights were "
                "released after the first build")

        start_time = time.time()
        print("build vgg19 model started")
        # Map [-1, 1] back to [0, 255].
        rgb_scaled = ((rgb + 1) * 255.0) / 2.0

        # Convert RGB to BGR and subtract the per-channel VGG_MEAN.
        red, green, blue = tf.split(axis=3, num_or_size_splits=3, value=rgb_scaled)
        bgr = tf.concat(axis=3, values=[blue - VGG_MEAN[0], green - VGG_MEAN[1], red - VGG_MEAN[2]])

        # Table-driven construction replaces 40 hand-unrolled lines (which
        # had pooled the pre-activation conv5_4). Every pool, including
        # pool5, now follows the block's last ReLU, as in VGG-19.
        features = bgr
        for block, depth in enumerate(self._BLOCK_DEPTHS, start=1):
            for idx in range(1, depth + 1):
                conv_name = "conv%d_%d" % (block, idx)
                relu_name = "relu%d_%d" % (block, idx)
                conv = self.conv_layer(features, conv_name)
                features = self.relu_layer(conv, relu_name)
                setattr(self, conv_name, conv)
                setattr(self, relu_name, features)
            pool_name = "pool%d" % block
            features = self.max_pool(features, pool_name)
            setattr(self, pool_name, features)

        # Weights are now baked into the graph as constants; drop the dict.
        self.data_dict = None
        print(("build vgg19 model finished: %ds" % (time.time() - start_time)))

    def max_pool(self, bottom, name):
        """2x2 max pooling with stride 2 and SAME padding."""
        return tf.nn.max_pool(bottom, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name=name)

    def relu_layer(self, bottom, name):
        """Named ReLU activation."""
        return tf.nn.relu(bottom, name=name)

    def conv_layer(self, bottom, name):
        """Stride-1 SAME convolution plus bias, using the pretrained weights for ``name``."""
        with tf.variable_scope(name):
            filt = self.get_conv_filter(name)

            conv = tf.nn.conv2d(bottom, filt, [1, 1, 1, 1], padding='SAME')

            conv_biases = self.get_bias(name)
            bias = tf.nn.bias_add(conv, conv_biases)

            return bias

    def get_conv_filter(self, name):
        """Return the pretrained filter for layer ``name`` as a constant."""
        return tf.constant(self.data_dict[name][0], name="filter")

    def get_bias(self, name):
        """Return the pretrained bias for layer ``name`` as a constant."""
        return tf.constant(self.data_dict[name][1], name="biases")

  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值