论文pdf 地址:https://arxiv.org/pdf/1609.04802v1.pdf
我的实际效果
清晰度距离我的期待有距离。
颜色上面存在差距。
解决想法
增加一个颜色判别器。将颜色值反馈给生成器
srgan论文是建立在gan基础上的,利用gan生成式对抗网络,将图片重构为高清分辨率的图片。
github上有开源的srgan项目。由于开源者,开发时考虑的问题更丰富,技巧更为高明,导致其代码都比较难以阅读和理解。
在为了充分理解这个论文。这里结合论文,开源代码,和自己的理解重新写了个srgan高清分辨率模型。
GAN原理
在一个不断提高判断能力的判断器的持续反馈下,不断改善生成器的生成参数,直到生成器生成的结果能够通过判断器的判断。(见本博客其他文章)
SRGAN用到的模块,及其关系
损失值,根据的这个关系结构计算的。
注意:vgg19是使用已经训练好的模型,这里只是拿来提取特征使用,
对于生成器,根据三个运算结果数据,进行随机梯度的优化调整
①判定器生成数据的鉴定结果
②vgg19的特征比较情况
③生成图形与理想图形的mse差距
论文中,生成器和判别器的模型图
生成器结构为:一层卷积,16层残差卷积,再将第一层卷积结果+16层残差结,卷积+2倍反卷积,卷积+2倍反卷积,tanh缩放,产生生成结果。
判别器结构为:8层卷积+reshape,全连接。(论文中,用了两层。我这里只用了一层全连接,参数量太大,我6G 的gpu内存不够用)
vgg19结构:在vgg19的第四层,返回获取到的特征结果,进行MSE对比
注意:BN处理,leaky relu等等处理技巧
代码解释
import numpy as np
import os
import tensorlayer as tl
import tensorflow as tf
#获取vgg9.npy中vgg19的参数,
vgg19_npy_path = "./vgg19.npy"
if not os.path.isfile(vgg19_npy_path):
print("Please download vgg19.npz from : https://github.com/machrisaa/tensorflow-vgg")
exit()
npz = np.load(vgg19_npy_path, encoding='latin1').item()
w_params = []
b_params = []
for val in sorted(npz.items()):
W = np.asarray(val[1][0])
b = np.asarray(val[1][1])
# print(" Loading %s: %s, %s" % (val[0], W.shape, b.shape))
w_params.append(W, )
b_params.extend(b)
#tensorlayer加载图片时,用于处理图片。随机获取图片中 192*192的矩阵, 内存不足时,可以优化这里
def crop_sub_imgs_fn(x, is_random=True):
x = tl.prepro.crop(x, wrg=192, hrg=192, is_random=is_random)
x = x / (255. / 2.)
x = x - 1.
return x
#resize矩阵 内存不足时,可以优化这里
def downsample_fn(x):
x = tl.prepro.imresize(x, size=[48, 48], interp='bicubic', mode=None)
x = x / (255. / 2.)
x = x - 1.
return x
# 参数
config = {
"epoch": 5,
}
# 内存不够时,可以减小这个
batch_size = 10
class SRGAN(object):
def __init__(self):
# with tf.device('/gpu:0'):
#占位变量,存储需要重构的图片
self.x = tf.placeholder(tf.float32, shape=[batch_size, 48, 48, 3], name='train_bechanged')
#占位变量,存储需要学习的理想中的图片
self.y = tf.placeholder(tf.float32, shape=[batch_size, 192, 192, 3], name='train_target')
self.init_fake_y = self.generator(self.x) # 预训练时生成的假照片
self.fake_y = self.generator(self.x, reuse=True) # 全部训练时生成的假照片
#占位变量,存储需要重构的测试图片
self.test_x = tf.placeholder(tf.float32, shape=[1, None, None, 3], name='test_generator')
#占位变量,存储重构后的测试图片
self.test_fake_y = self.generator(self.test_x, reuse=True) # 生成的假照片
#占位变量,将生成图片resize
self.fake_y_vgg = tf.image.resize_images(
self.fake_y, size=[224, 224], method=0,
align_corners=False)
#占位变量,将理想图片resize
self.real_y_vgg = tf.image.resize_images(
self.y, size=[224, 224], method=0,
align_corners=False)
#提取伪造图片的特征
self.fake_y_feature = self.vgg19(self.fake_y_vgg) # 假照片的特征值
#提取理想图片的特征
self.real_y_feature = self.vgg19(self.real_y_vgg, reuse=True) # 真照片的特征值
# self.pre_dis_logits = self.discriminator(self.fake_y) # 判别器生成的预测照片的判别值
self.fake_dis_logits = self.discriminator(self.fake_y, reuse=False) # 判别器生成的假照片的判别值
self.real_dis_logits = self.discriminator(self.y, reuse=True) # 判别器生成的假照片的判别值
# 预训练时,判别器的优化根据值
self.init_mse_loss = tf.losses.mean_squared_error(self.init_fake_y, self.y)
# 关于判别器的优化根据值
self.D_loos = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.real_dis_logits,
labels=tf.ones_like(
self.real_dis_logits))) + \
tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.fake_dis_logits,
labels=tf.zeros_like(
self.fake_dis_logits)))
# 伪造数据判别器的判断情况,生成与目标图像的差距,生成特征与理想特征的差距
self.D_loos_Ge = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.fake_dis_logits, labels=tf.ones_like( self.fake_dis_logits)))
self.mse_loss = tf.losses.mean_squared_error(self.fake_y,