Reference links:
https://blog.csdn.net/z704630835/article/details/84135277
https://github.com/dongheehand/DeblurGAN-tf
Dataset:
Code implementation:
The implementation consists of: data_loader.py (tf.data input pipeline), main.py (argument parsing and entry point), mode.py (training/testing loops), toPB.py (freeze a checkpoint into a PB model), test_pb.py (inference with the frozen PB), util.py and util2.py (image loading and batch generation), DeblurGAN.py (generator, discriminator, and losses), layer.py (convolution and instance-norm ops), and vgg19.py (VGG19 features for the perceptual loss).
data_loader.py
import tensorflow as tf
import os


class dataloader():

    def __init__(self, args):
        self.channel = 3
        self.mode = args.mode
        self.patch_size = args.patch_size
        self.batch_size = args.batch_size
        self.train_Sharp_path = args.train_Sharp_path
        self.train_Blur_path = args.train_Blur_path
        self.test_Sharp_path = args.test_Sharp_path
        self.test_Blur_path = args.test_Blur_path
        self.test_with_train = args.test_with_train
        self.test_batch = args.test_batch
        self.load_X = args.load_X
        self.load_Y = args.load_Y
        self.augmentation = args.augmentation

    def build_loader(self):
        if self.mode == 'train':
            tr_sharp_imgs = sorted(os.listdir(self.train_Sharp_path))
            tr_blur_imgs = sorted(os.listdir(self.train_Blur_path))
            tr_sharp_imgs = [os.path.join(self.train_Sharp_path, ele) for ele in tr_sharp_imgs]
            tr_blur_imgs = [os.path.join(self.train_Blur_path, ele) for ele in tr_blur_imgs]
            train_list = (tr_blur_imgs, tr_sharp_imgs)

            self.tr_dataset = tf.data.Dataset.from_tensor_slices(train_list)
            self.tr_dataset = self.tr_dataset.map(self._parse, num_parallel_calls=4).prefetch(32)
            self.tr_dataset = self.tr_dataset.map(self._resize, num_parallel_calls=4).prefetch(32)
            self.tr_dataset = self.tr_dataset.map(self._get_patch, num_parallel_calls=4).prefetch(32)
            if self.augmentation:
                self.tr_dataset = self.tr_dataset.map(self._data_augmentation, num_parallel_calls=4).prefetch(32)
            self.tr_dataset = self.tr_dataset.shuffle(32)
            self.tr_dataset = self.tr_dataset.repeat()
            self.tr_dataset = self.tr_dataset.batch(self.batch_size)

            if self.test_with_train:
                val_sharp_imgs = sorted(os.listdir(self.test_Sharp_path))
                val_blur_imgs = sorted(os.listdir(self.test_Blur_path))
                val_sharp_imgs = [os.path.join(self.test_Sharp_path, ele) for ele in val_sharp_imgs]
                val_blur_imgs = [os.path.join(self.test_Blur_path, ele) for ele in val_blur_imgs]
                valid_list = (val_blur_imgs, val_sharp_imgs)

                self.val_dataset = tf.data.Dataset.from_tensor_slices(valid_list)
                self.val_dataset = self.val_dataset.map(self._parse, num_parallel_calls=4).prefetch(32)
                self.val_dataset = self.val_dataset.batch(self.test_batch)

            iterator = tf.data.Iterator.from_structure(self.tr_dataset.output_types, self.tr_dataset.output_shapes)
            self.next_batch = iterator.get_next()
            self.init_op = {}
            self.init_op['tr_init'] = iterator.make_initializer(self.tr_dataset)
            if self.test_with_train:
                self.init_op['val_init'] = iterator.make_initializer(self.val_dataset)

        elif self.mode == 'test':
            val_sharp_imgs = sorted(os.listdir(self.test_Sharp_path))
            val_blur_imgs = sorted(os.listdir(self.test_Blur_path))
            val_sharp_imgs = [os.path.join(self.test_Sharp_path, ele) for ele in val_sharp_imgs]
            val_blur_imgs = [os.path.join(self.test_Blur_path, ele) for ele in val_blur_imgs]
            valid_list = (val_blur_imgs, val_sharp_imgs)

            self.val_dataset = tf.data.Dataset.from_tensor_slices(valid_list)
            self.val_dataset = self.val_dataset.map(self._parse, num_parallel_calls=4).prefetch(32)
            self.val_dataset = self.val_dataset.batch(1)

            iterator = tf.data.Iterator.from_structure(self.val_dataset.output_types, self.val_dataset.output_shapes)
            self.next_batch = iterator.get_next()
            self.init_op = {}
            self.init_op['val_init'] = iterator.make_initializer(self.val_dataset)

    def _parse(self, image_blur, image_sharp):
        image_blur = tf.read_file(image_blur)
        image_sharp = tf.read_file(image_sharp)
        image_blur = tf.image.decode_image(image_blur, channels=self.channel)
        image_sharp = tf.image.decode_image(image_sharp, channels=self.channel)
        image_blur = tf.cast(image_blur, tf.float32)
        image_sharp = tf.cast(image_sharp, tf.float32)
        return image_blur, image_sharp

    def _resize(self, image_blur, image_sharp):
        image_blur = tf.image.resize_images(image_blur, (self.load_Y, self.load_X), tf.image.ResizeMethod.BICUBIC)
        image_sharp = tf.image.resize_images(image_sharp, (self.load_Y, self.load_X), tf.image.ResizeMethod.BICUBIC)
        return image_blur, image_sharp

    def _parse_Blur_only(self, image_blur):
        image_blur = tf.read_file(image_blur)
        image_blur = tf.image.decode_image(image_blur, channels=self.channel)
        image_blur = tf.cast(image_blur, tf.float32)
        return image_blur

    def _get_patch(self, image_blur, image_sharp):
        shape = tf.shape(image_blur)
        ih = shape[0]
        iw = shape[1]
        ix = tf.random_uniform(shape=[1], minval=0, maxval=iw - self.patch_size + 1, dtype=tf.int32)[0]
        iy = tf.random_uniform(shape=[1], minval=0, maxval=ih - self.patch_size + 1, dtype=tf.int32)[0]
        img_sharp_in = image_sharp[iy:iy + self.patch_size, ix:ix + self.patch_size]
        img_blur_in = image_blur[iy:iy + self.patch_size, ix:ix + self.patch_size]
        return img_blur_in, img_sharp_in

    def _data_augmentation(self, image_blur, image_sharp):
        rot = tf.random_uniform(shape=[1], minval=0, maxval=3, dtype=tf.int32)[0]
        flip_rl = tf.random_uniform(shape=[1], minval=0, maxval=3, dtype=tf.int32)[0]
        flip_updown = tf.random_uniform(shape=[1], minval=0, maxval=3, dtype=tf.int32)[0]
        image_blur = tf.image.rot90(image_blur, rot)
        image_sharp = tf.image.rot90(image_sharp, rot)
        rl = tf.equal(tf.mod(flip_rl, 2), 0)
        ud = tf.equal(tf.mod(flip_updown, 2), 0)
        image_blur = tf.cond(rl, true_fn=lambda: tf.image.flip_left_right(image_blur),
                             false_fn=lambda: image_blur)
        image_sharp = tf.cond(rl, true_fn=lambda: tf.image.flip_left_right(image_sharp),
                              false_fn=lambda: image_sharp)
        image_blur = tf.cond(ud, true_fn=lambda: tf.image.flip_up_down(image_blur),
                             false_fn=lambda: image_blur)
        image_sharp = tf.cond(ud, true_fn=lambda: tf.image.flip_up_down(image_sharp),
                              false_fn=lambda: image_sharp)
        return image_blur, image_sharp
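A minimal usage sketch of this loader (assuming the argparse-style args defined in main.py below; note that the training loop in mode.py actually feeds placeholders through util2, so this tf.data path is only used if you wire it up yourself):

loader = dataloader(args)
loader.build_loader()
with tf.Session() as sess:
    sess.run(loader.init_op['tr_init'])          # point the shared iterator at the training set
    blur_batch, sharp_batch = sess.run(loader.next_batch)
    print(blur_batch.shape, sharp_batch.shape)   # (batch_size, patch_size, patch_size, 3)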
main.py
import tensorflow as tf
from DeblurGAN import DeblurGAN
from mode import *
import argparse
parser = argparse.ArgumentParser()
def str2bool(v):
    # compare against a tuple; the original `in ('true')` did substring matching on the string 'true'
    return v.lower() in ('true',)
## Model specification
parser.add_argument("--n_feats", type=int, default=64)
parser.add_argument("--num_of_down_scale", type=int, default=2)
parser.add_argument("--gen_resblocks", type=int, default=9)
parser.add_argument("--discrim_blocks", type=int, default=3)
## Data specification
parser.add_argument("--train_Sharp_path", type=str, default="data/train/sharp/")
parser.add_argument("--train_Blur_path", type=str, default="data/train/blur/")
parser.add_argument("--test_Sharp_path", type=str, default="data/test/sharp/")
parser.add_argument("--test_Blur_path", type=str, default="data/test/blur/")
parser.add_argument("--vgg_path", type=str, default="C:/Users/ywx613838/.keras/models/vgg19.npy")
parser.add_argument("--patch_size", type=int, default=256) #training size
parser.add_argument("--result_path", type=str, default="./result")
parser.add_argument("--model_path", type=str, default="./model")
## Optimization
parser.add_argument("--batch_size", type=int, default=2) #1 training-For-4
parser.add_argument("--max_epoch", type=int, default=2)
parser.add_argument("--learning_rate", type=float, default=1e-4)
parser.add_argument("--decay_step", type=int, default=150)
parser.add_argument("--test_with_train", type=str2bool, default=False)
parser.add_argument("--save_test_result", type=str2bool, default=True)
## Training or test specification
parser.add_argument("--mode", type=str, default="train")
parser.add_argument("--critic_updates", type=int, default=5)
parser.add_argument("--augmentation", type=str2bool, default=False)
parser.add_argument("--load_X", type=int, default=640)
parser.add_argument("--load_Y", type=int, default=360)
parser.add_argument("--fine_tuning", type=str2bool, default=False)
parser.add_argument("--log_freq", type=int, default=1)
parser.add_argument("--model_save_freq", type=int, default=1)
parser.add_argument("--pre_trained_model", type=str, default="./model/")
parser.add_argument("--test_batch", type=int, default=2)
args = parser.parse_args()
model = DeblurGAN(args)
model.build_graph()
print("Build DeblurGAN model!")
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver(max_to_keep=None)
if args.mode == 'train':
    train(args, model, sess, saver)
elif args.mode == 'test':
    f = open("test_results.txt", 'w')
    test(args, model, sess, saver, f, step=-1, loading=True)
    f.close()
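With the defaults above, training is launched as python main.py --mode train (optionally overriding --train_Blur_path, --train_Sharp_path and the other flags), and evaluation as python main.py --mode test --pre_trained_model ./model/, which writes the mean PSNR/SSIM to test_results.txt; the exact paths depend on where your data and checkpoints actually live.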
mode.py
import os
import tensorflow as tf
from PIL import Image
import numpy as np
import time
import util2


def train(args, model, sess, saver):
    if args.fine_tuning:
        saver.restore(sess, args.pre_trained_model)
        print("saved model is loaded for fine-tuning!")
        print("model path is %s" % args.pre_trained_model)

    num_imgs = len(os.listdir(args.train_Sharp_path))

    merged = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter('./logs', sess.graph)
    if args.test_with_train:
        f = open("valid_logs.txt", 'w')

    epoch = 0
    step = num_imgs // args.batch_size

    blur_imgs = util2.image_loader(args.train_Blur_path, args.load_X, args.load_Y)
    sharp_imgs = util2.image_loader(args.train_Sharp_path, args.load_X, args.load_Y)

    while epoch < args.max_epoch:
        random_index = np.random.permutation(len(blur_imgs))
        for k in range(step):
            s_time = time.time()
            blur_batch, sharp_batch = util2.batch_gen(blur_imgs, sharp_imgs, args.patch_size,
                                                      args.batch_size, random_index, k)
            # several discriminator (critic) updates per generator update
            for t in range(args.critic_updates):
                _, D_loss = sess.run([model.D_train, model.D_loss],
                                     feed_dict={model.blur: blur_batch, model.sharp: sharp_batch, model.epoch: epoch})
            _, G_loss = sess.run([model.G_train, model.G_loss],
                                 feed_dict={model.blur: blur_batch, model.sharp: sharp_batch, model.epoch: epoch})
            e_time = time.time()

        if epoch % args.log_freq == 0:
            summary = sess.run(merged, feed_dict={model.blur: blur_batch, model.sharp: sharp_batch})
            train_writer.add_summary(summary, epoch)
            if args.test_with_train:
                test(args, model, sess, saver, f, epoch, loading=False)
            print("%d training epoch completed" % epoch)
            print("D_loss : {}, \t G_loss : {}".format(D_loss, G_loss))
            print("Elapsed time : %0.4f" % (e_time - s_time))
            # print("D_loss : %0.4f, \t G_loss : %0.4f" % (D_loss, G_loss))
            # print("Elapsed time : %0.4f" % (e_time - s_time))

        if (epoch) % args.model_save_freq == 0:
            saver.save(sess, './model/DeblurrGAN', global_step=epoch, write_meta_graph=True)
        epoch += 1

    saver.save(sess, './model/DeblurrGAN_last', write_meta_graph=True)

    if args.test_with_train:
        f.close()


def test(args, model, sess, saver, file, step=-1, loading=False):
    if loading:
        import re
        print(" [*] Reading checkpoints...")
        ckpt = tf.train.get_checkpoint_state(args.pre_trained_model)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            saver.restore(sess, os.path.join(args.pre_trained_model, ckpt_name))
            print(" [*] Success to read {}".format(ckpt_name))
        else:
            print(" [*] Failed to find a checkpoint")

    blur_img_name = sorted(os.listdir(args.test_Blur_path))
    sharp_img_name = sorted(os.listdir(args.test_Sharp_path))

    PSNR_list = []
    ssim_list = []

    blur_imgs = util2.image_loader(args.test_Blur_path, args.load_X, args.load_Y, is_train=False)
    sharp_imgs = util2.image_loader(args.test_Sharp_path, args.load_X, args.load_Y, is_train=False)

    if not os.path.exists('./result/'):
        os.makedirs('./result/')

    for i, ele in enumerate(blur_imgs):
        blur = np.expand_dims(ele, axis=0)
        sharp = np.expand_dims(sharp_imgs[i], axis=0)
        # blur = np.transpose(blur, [0, 2, 3, 1])
        # sharp = np.transpose(sharp, [0, 2, 3, 1])
        output, psnr, ssim = sess.run([model.output, model.PSNR, model.ssim],
                                      feed_dict={model.blur: blur, model.sharp: sharp})
        if args.save_test_result:
            # de-normalization (model.output is already mapped back to uint8 in build_graph)
            # output = np.array((output[0] + 1) / 2.0 * 255.0).astype(np.uint8)
            output = Image.fromarray(output[0])
            split_name = blur_img_name[i].split('.')
            output.save(os.path.join(args.result_path, '%s_deblur.png' % (''.join(map(str, split_name[:-1])))))

        PSNR_list.append(psnr)
        ssim_list.append(ssim)

    length = len(PSNR_list)
    mean_PSNR = sum(PSNR_list) / length
    mean_ssim = sum(ssim_list) / length

    if step == -1:
        file.write('PSNR : {} SSIM : {}'.format(mean_PSNR, mean_ssim))
        file.close()
    else:
        file.write("{}-epoch step PSNR : {} SSIM : {} \n".format(step, mean_PSNR, mean_ssim))
toPB.py
import tensorflow as tf
import os
from tensorflow.python.framework import graph_util
from tensorflow.python import pywrap_tensorflow


def freeze_graph(input_checkpoint, output_graph):
    '''
    :param input_checkpoint: checkpoint prefix to freeze
    :param output_graph: path where the frozen PB model is saved
    :return:
    '''
    # checkpoint = tf.train.get_checkpoint_state(model_folder)  # check whether a usable ckpt exists in the folder
    # input_checkpoint = checkpoint.model_checkpoint_path  # get the ckpt file path

    # Name of the output node; it must exist in the original graph.
    output_node_names = "generator/clip"
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)

    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)  # restore the graph and its weights
        output_graph_def = graph_util.convert_variables_to_constants(  # freeze the variables into constants
            sess=sess,
            input_graph_def=sess.graph_def,
            output_node_names=output_node_names.split(","))  # multiple output nodes are comma-separated

        with tf.gfile.GFile(output_graph, "wb") as f:  # save the frozen model
            f.write(output_graph_def.SerializeToString())  # serialize and write
        print("%d ops in the final graph." % len(output_graph_def.node))  # number of ops in the frozen graph


if __name__ == '__main__':
    input_checkpoint = './'
    # # Inspect the tensor names stored in the checkpoint:
    # reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
    # var_to_shape_map = reader.get_variable_to_shape_map()
    # for key in var_to_shape_map:
    #     print("tensor_name: ", key)

    # Output path of the PB model
    out_pb_path = "./frozen_model.pb"
    # Convert the ckpt to PB with freeze_graph
    freeze_graph(input_checkpoint, out_pb_path)
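In practice input_checkpoint should point at the checkpoint prefix written by mode.py rather than './'. Assuming training was run with the default save paths above, something like:

input_checkpoint = './model/DeblurrGAN_last'   # or './model/DeblurrGAN-<epoch>' for an intermediate checkpoint
out_pb_path = './frozen_model.pb'
freeze_graph(input_checkpoint, out_pb_path)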
test_pb.py
import tensorflow as tf
import os
import cv2
from tensorflow.python.framework import graph_util
import numpy as np
from PIL import Image


def get_RAC(file_path, save_path, model_path):
    with tf.Graph().as_default():
        output_graph_def = tf.GraphDef()
        with open(model_path, "rb") as f:
            output_graph_def.ParseFromString(f.read())
            _ = tf.import_graph_def(output_graph_def, name="")
        with tf.Session() as sess:
            init = tf.global_variables_initializer()
            sess.run(init)

            # Load the image
            img = Image.open(file_path).convert('RGB')
            img = np.array(img)
            img = np.expand_dims(img, axis=0)

            input_x = sess.graph.get_tensor_by_name("blur:0")
            output = sess.graph.get_tensor_by_name("generator/clip:0")
            print("output:", output)

            output_image = sess.run(output, feed_dict={input_x: img})[0]
            print("output_img:", output_image.shape)

            # De-normalize the generator output from [-1, 1] back to [0, 255]
            output_image = (output_image + 1.0) * 255.0 / 2.0
            output_image = np.array(output_image).astype(np.uint8)
            output_image = Image.fromarray(output_image)
            output_image.save(save_path)


# Path of the frozen PB model
model_path = 'model/0000/test.pb'
# Directory of blurred test images (processed in batch)
path1 = 'data/test/blur/'
# Directory where the generated (deblurred) images are saved
path2 = 'data/result/'
# Iterate over the test images and deblur them one by one
for file in os.listdir(path1):
    file_path = path1 + file
    save_path = path2 + file
    get_RAC(file_path, save_path, model_path)
util.py
from PIL import Image
import numpy as np
import random
import os


def image_loader(image_path, load_x, load_y, is_train=True):
    imgs = sorted(os.listdir(image_path))
    img_list = []
    for ele in imgs:
        img = Image.open(os.path.join(image_path, ele))
        if is_train:
            img = img.resize((load_x, load_y), Image.BICUBIC)
        img_list.append(np.array(img))
    return img_list


def data_augument(lr_img, hr_img, aug):
    if aug < 4:
        lr_img = np.rot90(lr_img, aug)
        hr_img = np.rot90(hr_img, aug)
    elif aug == 4:
        lr_img = np.fliplr(lr_img)
        hr_img = np.fliplr(hr_img)
    elif aug == 5:
        lr_img = np.flipud(lr_img)
        hr_img = np.flipud(hr_img)
    elif aug == 6:
        lr_img = np.rot90(np.fliplr(lr_img))
        hr_img = np.rot90(np.fliplr(hr_img))
    elif aug == 7:
        lr_img = np.rot90(np.flipud(lr_img))
        hr_img = np.rot90(np.flipud(hr_img))
    return lr_img, hr_img


def batch_gen(blur_imgs, sharp_imgs, patch_size, batch_size, random_index, step, augment=False):
    img_index = random_index[step * batch_size: (step + 1) * batch_size]

    all_img_blur = []
    all_img_sharp = []
    for _index in img_index:
        all_img_blur.append(blur_imgs[_index])
        all_img_sharp.append(sharp_imgs[_index])

    blur_batch = []
    sharp_batch = []
    for i in range(len(all_img_blur)):
        ih, iw, _ = all_img_blur[i].shape
        ix = random.randrange(0, iw - patch_size + 1)
        iy = random.randrange(0, ih - patch_size + 1)
        img_blur_in = all_img_blur[i][iy:iy + patch_size, ix:ix + patch_size]
        img_sharp_in = all_img_sharp[i][iy:iy + patch_size, ix:ix + patch_size]
        if augment:
            aug = random.randrange(0, 8)
            img_blur_in, img_sharp_in = data_augument(img_blur_in, img_sharp_in, aug)
        blur_batch.append(img_blur_in)
        sharp_batch.append(img_sharp_in)

    blur_batch = np.array(blur_batch)
    sharp_batch = np.array(sharp_batch)
    return blur_batch, sharp_batch
util2.py
from PIL import Image
import numpy as np
import random
import os

fine_size = 512
miu = np.array([0.5, 0.5, 0.5])
std = np.array([0.5, 0.5, 0.5])


def preprocess(img):
    # # Normalization
    # img_np = np.array(img, dtype=float) / 255.
    # r = (img_np[:, :, 0] - miu[0]) / std[0]
    # g = (img_np[:, :, 1] - miu[1]) / std[1]
    # b = (img_np[:, :, 2] - miu[2]) / std[2]
    # img_np_t = np.array([r, g, b])
    # print(img_np_t.shape)

    # No normalization
    img_np_t = img

    # Random crop
    W = img_np_t.shape[1]
    H = img_np_t.shape[2]
    Ws = np.random.randint(0, W - fine_size - 1, 1)[0]
    Hs = np.random.randint(0, H - fine_size - 1, 1)[0]
    img_np_t = img_np_t[:, Ws:Ws + fine_size, Hs:Hs + fine_size]

    # # Random crop (HWC layout)
    # ih, iw, _ = img_np_t.shape
    # ix = random.randrange(0, iw - fine_size + 1)
    # iy = random.randrange(0, ih - fine_size + 1)
    # #
    # img_np_t = img_np_t[Ws:Ws + fine_size, Hs:Hs + fine_size]
    return img_np_t


def test_preprocess(img):
    # Normalization
    img_np = np.array(img, dtype=float) / 255.
    r = (img_np[:, :, 0] - miu[0]) / std[0]
    g = (img_np[:, :, 1] - miu[1]) / std[1]
    b = (img_np[:, :, 2] - miu[2]) / std[2]
    img_np_t = np.array([r, g, b])
    return img_np_t


def image_loader(image_path, load_x, load_y, is_train=True):
    imgs = sorted(os.listdir(image_path))
    img_list = []
    for ele in imgs:
        img = Image.open(os.path.join(image_path, ele)).convert('RGB')
        if is_train:  # preprocess training images; original version: resize only, no normalization
            img = img.resize((load_x, load_y), Image.BICUBIC)
            # img = preprocess(img)
        # else:
        #     img = test_preprocess(img)
        img_list.append(np.array(img))
    return img_list


def data_augument(lr_img, hr_img, aug):
    if aug < 4:
        lr_img = np.rot90(lr_img, aug)
        hr_img = np.rot90(hr_img, aug)
    elif aug == 4:
        lr_img = np.fliplr(lr_img)
        hr_img = np.fliplr(hr_img)
    elif aug == 5:
        lr_img = np.flipud(lr_img)
        hr_img = np.flipud(hr_img)
    elif aug == 6:
        lr_img = np.rot90(np.fliplr(lr_img))
        hr_img = np.rot90(np.fliplr(hr_img))
    elif aug == 7:
        lr_img = np.rot90(np.flipud(lr_img))
        hr_img = np.rot90(np.flipud(hr_img))
    return lr_img, hr_img


# Random cropping happens again here; for testing this would need to be changed (e.g. comment this part out).
def batch_gen(blur_imgs, sharp_imgs, patch_size, batch_size, random_index, step, augment=False):
    img_index = random_index[step * batch_size: (step + 1) * batch_size]

    all_img_blur = []
    all_img_sharp = []
    for _index in img_index:
        all_img_blur.append(blur_imgs[_index])
        all_img_sharp.append(sharp_imgs[_index])

    blur_batch = []
    sharp_batch = []
    for i in range(len(all_img_blur)):
        ih, iw, _ = all_img_blur[i].shape
        ix = random.randrange(0, iw - patch_size + 1)
        iy = random.randrange(0, ih - patch_size + 1)
        img_blur_in = all_img_blur[i][iy:iy + patch_size, ix:ix + patch_size]
        img_sharp_in = all_img_sharp[i][iy:iy + patch_size, ix:ix + patch_size]
        # img_blur_in = all_img_blur[i]
        # img_sharp_in = all_img_sharp[i]
        # if augment:
        #     aug = random.randrange(0, 8)
        #     img_blur_in, img_sharp_in = data_augument(img_blur_in, img_sharp_in, aug)
        blur_batch.append(img_blur_in)
        sharp_batch.append(img_sharp_in)

    blur_batch = np.array(blur_batch)
    sharp_batch = np.array(sharp_batch)
    # # transpose
    # blur_batch = np.transpose(blur_batch, [0, 2, 3, 1])
    # sharp_batch = np.transpose(sharp_batch, [0, 2, 3, 1])
    return blur_batch, sharp_batch
DeblurGAN.py
from layer import *
from data_loader import dataloader
from vgg19 import Vgg19


class DeblurGAN():

    def __init__(self, args):
        self.data_loader = dataloader(args)
        print("data has been loaded")
        self.channel = 3
        self.n_feats = args.n_feats
        self.mode = args.mode
        self.batch_size = args.batch_size
        self.num_of_down_scale = args.num_of_down_scale
        self.gen_resblocks = args.gen_resblocks
        self.discrim_blocks = args.discrim_blocks
        self.vgg_path = args.vgg_path
        self.learning_rate = args.learning_rate
        self.decay_step = args.decay_step

    def down_scaling_feature(self, name, x, n_feats):
        x = Conv(name=name + 'conv', x=x, filter_size=3, in_filters=n_feats,
                 out_filters=n_feats * 2, strides=2, padding='SAME')
        x = instance_norm(x)
        x = tf.nn.relu(x)
        return x

    def up_scaling_feature(self, name, x, n_feats):
        x = Conv_transpose(name=name + 'deconv', x=x, filter_size=3, in_filters=n_feats,
                           out_filters=n_feats // 2, fraction=2, padding='SAME')
        x = instance_norm(x)
        x = tf.nn.relu(x)
        return x

    def res_block(self, name, x, n_feats):
        _res = x
        x = tf.pad(x, [[0, 0], [1, 1], [1, 1], [0, 0]], mode='REFLECT')
        x = Conv(name=name + 'conv1', x=x, filter_size=3, in_filters=n_feats,
                 out_filters=n_feats, strides=1, padding='VALID')
        x = instance_norm(x)
        x = tf.nn.relu(x)
        x = tf.pad(x, [[0, 0], [1, 1], [1, 1], [0, 0]], mode='REFLECT')
        x = Conv(name=name + 'conv2', x=x, filter_size=3, in_filters=n_feats,
                 out_filters=n_feats, strides=1, padding='VALID')
        x = instance_norm(x)
        x = x + _res
        return x

    def generator(self, x, reuse=False, name='generator'):
        with tf.variable_scope(name_or_scope=name, reuse=reuse):
            _res = x
            x = tf.pad(x, [[0, 0], [3, 3], [3, 3], [0, 0]], mode='REFLECT')
            x = Conv(name='conv1', x=x, filter_size=7, in_filters=self.channel,
                     out_filters=self.n_feats, strides=1, padding='VALID')
            x = instance_norm(x)
            x = tf.nn.relu(x)

            for i in range(self.num_of_down_scale):
                x = self.down_scaling_feature(name='down_%02d' % i, x=x, n_feats=self.n_feats * (i + 1))

            for i in range(self.gen_resblocks):
                x = self.res_block(name='res_%02d' % i, x=x, n_feats=self.n_feats * (2 ** self.num_of_down_scale))

            for i in range(self.num_of_down_scale):
                x = self.up_scaling_feature(name='up_%02d' % i, x=x,
                                            n_feats=self.n_feats * (2 ** (self.num_of_down_scale - i)))

            x = tf.pad(x, [[0, 0], [3, 3], [3, 3], [0, 0]], mode='REFLECT')
            x = Conv(name='conv_last', x=x, filter_size=7, in_filters=self.n_feats,
                     out_filters=self.channel, strides=1, padding='VALID')
            x = tf.nn.tanh(x)
            x = x + _res
            x = tf.clip_by_value(x, -1.0, 1.0, name="clip")
            print(x)
            return x

    def discriminator(self, x, reuse=False, name='discriminator'):
        with tf.variable_scope(name_or_scope=name, reuse=reuse):
            x = Conv(name='conv1', x=x, filter_size=4, in_filters=self.channel,
                     out_filters=self.n_feats, strides=2, padding="SAME")
            x = instance_norm(x)
            x = tf.nn.leaky_relu(x)

            n = 1
            for i in range(self.discrim_blocks):
                prev = n
                n = min(2 ** (i + 1), 8)
                x = Conv(name='conv%02d' % i, x=x, filter_size=4, in_filters=self.n_feats * prev,
                         out_filters=self.n_feats * n, strides=2, padding="SAME")
                x = instance_norm(x)
                x = tf.nn.leaky_relu(x)

            prev = n
            n = min(2 ** self.discrim_blocks, 8)
            x = Conv(name='conv_d1', x=x, filter_size=4, in_filters=self.n_feats * prev,
                     out_filters=self.n_feats * n, strides=1, padding="SAME")
            # x = instance_norm(name = 'instance_norm_d1', x = x, dim = self.n_feats * n)
            x = instance_norm(x)
            x = tf.nn.leaky_relu(x)

            x = Conv(name='conv_d2', x=x, filter_size=4, in_filters=self.n_feats * n,
                     out_filters=1, strides=1, padding="SAME")
            x = tf.nn.sigmoid(x)
            return x

    def build_graph(self):
        # if self.in_memory:
        self.blur = tf.placeholder(name="blur", shape=[None, None, None, self.channel], dtype=tf.float32)
        self.sharp = tf.placeholder(name="sharp", shape=[None, None, None, self.channel], dtype=tf.float32)
        x = self.blur
        label = self.sharp
        self.epoch = tf.placeholder(name='train_step', shape=None, dtype=tf.int32)

        x = (2.0 * x / 255.0) - 1.0
        label = (2.0 * label / 255.0) - 1.0

        self.gene_img = self.generator(x, reuse=False)
        self.real_prob = self.discriminator(label, reuse=False)
        self.fake_prob = self.discriminator(self.gene_img, reuse=True)

        epsilon = tf.random_uniform(shape=[self.batch_size, 1, 1, 1], minval=0.0, maxval=1.0)
        interpolated_input = epsilon * label + (1 - epsilon) * self.gene_img
        gradient = tf.gradients(self.discriminator(interpolated_input, reuse=True), [interpolated_input])[0]
        GP_loss = tf.reduce_mean(tf.square(tf.sqrt(tf.reduce_mean(tf.square(gradient), axis=[1, 2, 3])) - 1))

        d_loss_real = - tf.reduce_mean(self.real_prob)
        d_loss_fake = tf.reduce_mean(self.fake_prob)

        self.vgg_net = Vgg19(self.vgg_path)
        self.vgg_net.build(tf.concat([label, self.gene_img], axis=0))
        self.content_loss = tf.reduce_mean(tf.reduce_sum(tf.square(
            self.vgg_net.relu3_3[self.batch_size:] - self.vgg_net.relu3_3[:self.batch_size]), axis=3))

        self.D_loss = d_loss_real + d_loss_fake + 10.0 * GP_loss
        self.G_loss = - d_loss_fake + 100.0 * self.content_loss

        t_vars = tf.trainable_variables()
        G_vars = [var for var in t_vars if 'generator' in var.name]
        D_vars = [var for var in t_vars if 'discriminator' in var.name]

        lr = tf.minimum(self.learning_rate, tf.abs(2 * self.learning_rate - (
            self.learning_rate * tf.cast(self.epoch, tf.float32) / self.decay_step)))
        self.D_train = tf.train.AdamOptimizer(learning_rate=lr).minimize(self.D_loss, var_list=D_vars)
        self.G_train = tf.train.AdamOptimizer(learning_rate=lr).minimize(self.G_loss, var_list=G_vars)

        self.PSNR = tf.reduce_mean(tf.image.psnr(((self.gene_img + 1.0) / 2.0), ((label + 1.0) / 2.0), max_val=1.0))
        self.ssim = tf.reduce_mean(tf.image.ssim(((self.gene_img + 1.0) / 2.0), ((label + 1.0) / 2.0), max_val=1.0))

        logging_D_loss = tf.summary.scalar(name='D_loss', tensor=self.D_loss)
        logging_G_loss = tf.summary.scalar(name='G_loss', tensor=self.G_loss)
        logging_PSNR = tf.summary.scalar(name='PSNR', tensor=self.PSNR)
        logging_ssim = tf.summary.scalar(name='ssim', tensor=self.ssim)

        self.output = (self.gene_img + 1.0) * 255.0 / 2.0
        self.output = tf.round(self.output)
        self.output = tf.cast(self.output, tf.uint8)
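In equation form, the objective built above is the WGAN-GP loss plus a VGG19 perceptual term, with the same weights as in the code (10 for the gradient penalty, 100 for the content loss):

D_loss = E[D(G(x))] - E[D(y)] + 10 * E[(||∇x̂ D(x̂)|| - 1)^2],  with x̂ = ε·y + (1-ε)·G(x), ε ~ U(0, 1)
G_loss = -E[D(G(x))] + 100 * content_loss

Here x is the blurred input scaled to [-1, 1], y the sharp target, G the generator and D the discriminator; the gradient "norm" is computed in the code as the root-mean-square of the gradient over height, width and channels, and content_loss is the squared difference of the VGG19 relu3_3 features of G(x) and y, summed over channels and averaged over the batch and spatial positions.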
layer.py
import tensorflow as tf
import numpy as np


def Conv(name, x, filter_size, in_filters, out_filters, strides, padding):
    with tf.variable_scope(name):
        kernel = tf.get_variable('filter', [filter_size, filter_size, in_filters, out_filters], tf.float32,
                                 initializer=tf.random_normal_initializer(stddev=0.01))
        bias = tf.get_variable('bias', [out_filters], tf.float32, initializer=tf.zeros_initializer())
        return tf.nn.conv2d(x, kernel, [1, strides, strides, 1], padding=padding) + bias


def Conv_transpose(name, x, filter_size, in_filters, out_filters, fraction=2, padding="SAME"):
    with tf.variable_scope(name):
        n = filter_size * filter_size * out_filters
        kernel = tf.get_variable('filter', [filter_size, filter_size, out_filters, in_filters], tf.float32,
                                 initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0 / n)))
        size = tf.shape(x)
        output_shape = tf.stack([size[0], size[1] * fraction, size[2] * fraction, out_filters])
        x = tf.nn.conv2d_transpose(x, kernel, output_shape, [1, fraction, fraction, 1], padding)
        return x


# def instance_norm(x, BN_epsilon=1e-3):
#     inputs_rank = x.shape.ndims
#     moments_axes = list(range(inputs_rank))
#     # mean, variance = tf.nn.moments(x, axes=[1, 2], keep_dims=True)  # batch_size=1
#     mean, variance = tf.nn.moments(x, axes=moments_axes, keep_dims=True)
#     x = (x - mean) / ((variance + BN_epsilon) ** 0.5)
#     return x


def instance_norm(x):
    x_norm = tf.contrib.layers.instance_norm(x)  # tf 1.8
    return x_norm
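tf.contrib (and hence tf.contrib.layers.instance_norm) only exists in TensorFlow 1.x. If it is unavailable, a minimal drop-in sketch in the spirit of the commented-out version above (assuming NHWC input, and without the learnable scale/offset that the contrib op adds) would be:

def instance_norm_manual(x, epsilon=1e-5):
    # per-sample, per-channel statistics over the spatial dimensions
    mean, variance = tf.nn.moments(x, axes=[1, 2], keep_dims=True)
    return (x - mean) / tf.sqrt(variance + epsilon)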
vgg19.py
import tensorflow as tf
import numpy as np
import time

VGG_MEAN = [103.939, 116.779, 123.68]


class Vgg19:

    def __init__(self, vgg19_npy_path):
        self.data_dict = np.load(vgg19_npy_path, encoding='latin1').item()
        print("npy file loaded")

    def build(self, rgb):
        """
        load variable from npy to build the VGG
        :param rgb: rgb image [batch, height, width, 3] values scaled [-1, 1]
        """
        start_time = time.time()
        print("build vgg19 model started")
        rgb_scaled = ((rgb + 1) * 255.0) / 2.0

        # Convert RGB to BGR
        red, green, blue = tf.split(axis=3, num_or_size_splits=3, value=rgb_scaled)
        bgr = tf.concat(axis=3, values=[blue - VGG_MEAN[0], green - VGG_MEAN[1], red - VGG_MEAN[2]])

        self.conv1_1 = self.conv_layer(bgr, "conv1_1")
        self.relu1_1 = self.relu_layer(self.conv1_1, "relu1_1")
        self.conv1_2 = self.conv_layer(self.relu1_1, "conv1_2")
        self.relu1_2 = self.relu_layer(self.conv1_2, "relu1_2")
        self.pool1 = self.max_pool(self.relu1_2, 'pool1')

        self.conv2_1 = self.conv_layer(self.pool1, "conv2_1")
        self.relu2_1 = self.relu_layer(self.conv2_1, "relu2_1")
        self.conv2_2 = self.conv_layer(self.relu2_1, "conv2_2")
        self.relu2_2 = self.relu_layer(self.conv2_2, "relu2_2")
        self.pool2 = self.max_pool(self.relu2_2, 'pool2')

        self.conv3_1 = self.conv_layer(self.pool2, "conv3_1")
        self.relu3_1 = self.relu_layer(self.conv3_1, "relu3_1")
        self.conv3_2 = self.conv_layer(self.relu3_1, "conv3_2")
        self.relu3_2 = self.relu_layer(self.conv3_2, "relu3_2")
        self.conv3_3 = self.conv_layer(self.relu3_2, "conv3_3")
        self.relu3_3 = self.relu_layer(self.conv3_3, "relu3_3")
        self.conv3_4 = self.conv_layer(self.relu3_3, "conv3_4")
        self.relu3_4 = self.relu_layer(self.conv3_4, "relu3_4")
        self.pool3 = self.max_pool(self.relu3_4, 'pool3')

        self.conv4_1 = self.conv_layer(self.pool3, "conv4_1")
        self.relu4_1 = self.relu_layer(self.conv4_1, "relu4_1")
        self.conv4_2 = self.conv_layer(self.relu4_1, "conv4_2")
        self.relu4_2 = self.relu_layer(self.conv4_2, "relu4_2")
        self.conv4_3 = self.conv_layer(self.relu4_2, "conv4_3")
        self.relu4_3 = self.relu_layer(self.conv4_3, "relu4_3")
        self.conv4_4 = self.conv_layer(self.relu4_3, "conv4_4")
        self.relu4_4 = self.relu_layer(self.conv4_4, "relu4_4")
        self.pool4 = self.max_pool(self.relu4_4, 'pool4')

        self.conv5_1 = self.conv_layer(self.pool4, "conv5_1")
        self.relu5_1 = self.relu_layer(self.conv5_1, "relu5_1")
        self.conv5_2 = self.conv_layer(self.relu5_1, "conv5_2")
        self.relu5_2 = self.relu_layer(self.conv5_2, "relu5_2")
        self.conv5_3 = self.conv_layer(self.relu5_2, "conv5_3")
        self.relu5_3 = self.relu_layer(self.conv5_3, "relu5_3")
        self.conv5_4 = self.conv_layer(self.relu5_3, "conv5_4")
        self.relu5_4 = self.relu_layer(self.conv5_4, "relu5_4")
        self.pool5 = self.max_pool(self.conv5_4, 'pool5')

        self.data_dict = None
        print(("build vgg19 model finished: %ds" % (time.time() - start_time)))

    def max_pool(self, bottom, name):
        return tf.nn.max_pool(bottom, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name=name)

    def relu_layer(self, bottom, name):
        return tf.nn.relu(bottom, name=name)

    def conv_layer(self, bottom, name):
        with tf.variable_scope(name):
            filt = self.get_conv_filter(name)
            conv = tf.nn.conv2d(bottom, filt, [1, 1, 1, 1], padding='SAME')
            conv_biases = self.get_bias(name)
            bias = tf.nn.bias_add(conv, conv_biases)
            return bias

    def get_conv_filter(self, name):
        return tf.constant(self.data_dict[name][0], name="filter")

    def get_bias(self, name):
        return tf.constant(self.data_dict[name][1], name="biases")