GAN生成式对抗网络(四)——SRGAN超高分辨率图片重构

本文介绍了基于GAN的SRGAN模型,用于将低分辨率图片转换为高清图像。作者分享了实际效果和存在的问题,如颜色差异,并提出了增加颜色判别器的解决方案。内容涵盖GAN原理、SRGAN的模块关系、论文模型图和代码解析,同时提供了查看效果的工具函数。
摘要由CSDN通过智能技术生成

论文pdf 地址:https://arxiv.org/pdf/1609.04802v1.pdf

我的实际效果

清晰度距离我的期待有距离。
颜色上面存在差距。
解决想法
增加一个颜色判别器。将颜色值反馈给生成器

1545753-20181128120704364-961028862.png

srgan论文是建立在gan基础上的,利用gan生成式对抗网络,将图片重构为高清分辨率的图片。
github上有开源的srgan项目。由于开源者,开发时考虑的问题更丰富,技巧更为高明,导致其代码都比较难以阅读和理解。
在为了充分理解这个论文。这里结合论文,开源代码,和自己的理解重新写了个srgan高清分辨率模型。

GAN原理

在一个不断提高判断能力的判断器的持续反馈下,不断改善生成器的生成参数,直到生成器生成的结果能够通过判断器的判断。(见本博客其他文章)

SRGAN用到的模块,及其关系

损失值,根据的这个关系结构计算的。
1545753-20181127163533885-1386223271.png
注意:vgg19是使用已经训练好的模型,这里只是拿来提取特征使用,

对于生成器,根据三个运算结果数据,进行随机梯度的优化调整
①判定器生成数据的鉴定结果
②vgg19的特征比较情况
③生成图形与理想图形的mse差距

论文中,生成器和判别器的模型图

1545753-20181127170109742-1985386475.png
生成器结构为:一层卷积,16层残差卷积,再将第一层卷积结果+16层残差结,卷积+2倍反卷积,卷积+2倍反卷积,tanh缩放,产生生成结果。
判别器结构为:8层卷积+reshape,全连接。(论文中,用了两层。我这里只用了一层全连接,参数量太大,我6G 的gpu内存不够用)
vgg19结构:在vgg19的第四层,返回获取到的特征结果,进行MSE对比
注意:BN处理,leaky relu等等处理技巧

代码解释

import numpy as np
import os
import tensorlayer as tl
import tensorflow as tf

#获取vgg9.npy中vgg19的参数, 
vgg19_npy_path = "./vgg19.npy"
if not os.path.isfile(vgg19_npy_path):
    print("Please download vgg19.npz from : https://github.com/machrisaa/tensorflow-vgg")
    exit()
npz = np.load(vgg19_npy_path, encoding='latin1').item()
w_params = []
b_params = []
for val in sorted(npz.items()):
    W = np.asarray(val[1][0])
    b = np.asarray(val[1][1])
    # print("  Loading %s: %s, %s" % (val[0], W.shape, b.shape))
    w_params.append(W, )
    b_params.extend(b)


#tensorlayer加载图片时,用于处理图片。随机获取图片中 192*192的矩阵, 内存不足时,可以优化这里
def crop_sub_imgs_fn(x, is_random=True):
    x = tl.prepro.crop(x, wrg=192, hrg=192, is_random=is_random)
    x = x / (255. / 2.)
    x = x - 1.
    return x
#resize矩阵 内存不足时,可以优化这里
def downsample_fn(x):
    x = tl.prepro.imresize(x, size=[48, 48], interp='bicubic', mode=None)
    x = x / (255. / 2.)
    x = x - 1.
    return x

# 参数
config = {
    "epoch": 5,
}

# 内存不够时,可以减小这个
batch_size = 10 


class SRGAN(object):
    def __init__(self):
        # with tf.device('/gpu:0'):
        #占位变量,存储需要重构的图片
        self.x = tf.placeholder(tf.float32, shape=[batch_size, 48, 48, 3], name='train_bechanged')
        #占位变量,存储需要学习的理想中的图片
        self.y = tf.placeholder(tf.float32, shape=[batch_size, 192, 192, 3], name='train_target')
        self.init_fake_y = self.generator(self.x)  # 预训练时生成的假照片
        self.fake_y = self.generator(self.x, reuse=True)  # 全部训练时生成的假照片

         #占位变量,存储需要重构的测试图片
        self.test_x = tf.placeholder(tf.float32, shape=[1, None, None, 3], name='test_generator')
        #占位变量,存储重构后的测试图片
        self.test_fake_y = self.generator(self.test_x, reuse=True)  # 生成的假照片

        #占位变量,将生成图片resize
        self.fake_y_vgg = tf.image.resize_images(
            self.fake_y, size=[224, 224], method=0,
            align_corners=False)
         #占位变量,将理想图片resize
        self.real_y_vgg = tf.image.resize_images(
            self.y, size=[224, 224], method=0,
            align_corners=False)
        #提取伪造图片的特征
        self.fake_y_feature = self.vgg19(self.fake_y_vgg)  # 假照片的特征值
        #提取理想图片的特征
        self.real_y_feature = self.vgg19(self.real_y_vgg, reuse=True)  # 真照片的特征值

        # self.pre_dis_logits = self.discriminator(self.fake_y)  # 判别器生成的预测照片的判别值
        self.fake_dis_logits = self.discriminator(self.fake_y, reuse=False)  # 判别器生成的假照片的判别值
        self.real_dis_logits = self.discriminator(self.y, reuse=True)  # 判别器生成的假照片的判别值

        # 预训练时,判别器的优化根据值
        self.init_mse_loss = tf.losses.mean_squared_error(self.init_fake_y, self.y)

        # 关于判别器的优化根据值
        self.D_loos = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.real_dis_logits,
                                                                             labels=tf.ones_like(
                                                                                 self.real_dis_logits))) + \
                      tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.fake_dis_logits,
                                                                             labels=tf.zeros_like(
                                                                                 self.fake_dis_logits)))

        # 伪造数据判别器的判断情况,生成与目标图像的差距,生成特征与理想特征的差距
        self.D_loos_Ge = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.fake_dis_logits, labels=tf.ones_like( self.fake_dis_logits)))
        self.mse_loss = tf.losses.mean_squared_error(self.fake_y, 
  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值