SRGAN实现超分辨率图像重建之模型复现

1.论文介绍

1.1简介

论文名称《Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial》

Ledig C , Theis L , Huszar F , et al. Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network[J]. 2016.

该文针对传统超分辨方法中存在的结果过于平滑的问题,提出了结合最新的对抗网络的方法,得到了不错的效果。并且针对此网络结构,构建了自己的感知损失函数。

传统的超分辨率模型采用均方差作为损失函数,虽然可以获得很高的峰值信噪比,但是恢复出来的图像会丢失高频细节。

因此,该文提出了一种自己设计的感知损失函数:

1.2 content loss

像素级 MSE Loss 的计算为:

   这个是最经常使用的优化目标。但是,这种方式当取得较高的 PSNR的同时,MSE 优化问题导致缺乏 high-frequency content,这就会使得结果太过于平滑(overly smooth solutions)。因此在在 pre-trained 19-layer VGG network 的 ReLU activation layers 的基础上,定义了 VGG loss 

    其中,Wi,jWi,j  and Hi,jHi,j 表示了 VGG network 当中相应的 feature maps 的维度。

1.3 Adversarial Loss 

    在所有训练样本上,基于判别器的概率定义 generative loss :

    此处,D 是重构图像是 natural HR image 的概率。

  

1.4 Regulatization Loss 

    我们进一步的采用 基于 total variation 的正则化项来鼓励 spatially coherent solutions。正则化损失的定义为:

 

效果展示

3. 模型复现

3.1 整体结构

分为两部分,一部分为产生网络,一部分为鉴别网络,接下来分别对其进行复现。

3.2 Generator Network

首先由结构图可以看出,整体大致分为三部分。

第一部分input输入一张低分辨率的图像,并进行一次卷积和ReLU操作

  #第一部分,传入低分辨率图像
        LR_input = Input(shape=self.LR_shape)
        layer1 = Conv2D(64, kernel_size=3, strides=1, padding='same')(LR_input)
        layer1 = Activation('relu')(layer1)

第二部分会经过B个残差块,每个残差块的组成为两次卷积,两次BN层,一次ReLU,最后有一个残差边。

先构建一个残差块函数,方便使用

        def residual_block(input, filter):
            layer = Conv2D(filter, kernel_size=3, strides=1, padding='same')(input)
            # 衰减率暂时设0.8,不考虑性能,优先考虑回环检测需求
            layer = BatchNormalization(momentum=0.8)(layer)
            layer = Activation('relu')(layer)
            layer = Conv2D(filter, kernel_size=3, strides=1, padding='same')(layer)
            layer = BatchNormalization(momentum=0.8)(layer)
            layer = Add()([layer,input])
            return layer

经过b个残差结构块: 

  #第二部分,经过b个残差结构块
        layer2 = residual_block(layer1, 64)
        for _ in range(self.b_residual_blocks - 1):
            layer2 = residual_block(layer2, 64)

第三部分经过两次Deconv上采样和ReLU,再经过一次卷积,将图像扩大为原来的四倍来提高分辨率。

先构建解码函数:

  def deConv(input):
            layer = UpSampling2D(size=2)(input)
            layer = Conv2D(256, kernel_size=3, strides=1, padding='same')(layer)
            layer = Activation('relu')(layer)
            return layer
 #第三部分,上采样图像放大为原来的4倍
        layer3 = Conv2D(64, kernel_size=3, strides=1, padding='same')(layer2)
        layer3 = BatchNormalization(momentum=0.8)(layer3)
        layer3 = Add()([layer3,layer1])

        res1 = deConv(layer3)
        res2 = deConv(res1)
        genernator_HR = Conv2D(self.channels, kernel_size=9, strides=1, padding='same', activation='tanh')(res2)

        # 返回原低分辨率图像以及生成的伪高清图像
        return Model(LR_input, genernator_HR)

最终输出高分辨率的图像。

3.3 Discriminator Network

整体看还是分为三个大部分:

第一部分对input输入图像进行两次卷积和Leaky ReLU。

# 公用卷积块
        def conv_block(input, filter, strides=1, BN=True):
            block = Conv2D(filter, kernel_size=3, strides=strides, padding='same')(input)
            #参数慢慢调吧
            block = LeakyReLU(alpha=0.2)(block)
            if BN:
                block = BatchNormalization(momentum=0.8)(block)
            return block

        # 过滤器和步长参考的论文的结构图上的
        # 第一部分
        input = Input(shape=self.HR_shape)
        layer1 = conv_block(input=input, filter=64, BN=False)
        layer2 = conv_block(layer1,64,strides=2)

第二部分是经过6次卷积+Leaky ReLU+BN操作。

#第二部分
        layer3 = conv_block(layer2, 128)
        layer4 = conv_block(layer3, 128, strides=2)
        layer5 = conv_block(layer4, 256)
        layer6 = conv_block(layer5, 256, strides=2)
        layer7 = conv_block(layer6, 512)
        layer8 = conv_block(layer7, 512,strides=2)

第三部分是经过一个Dense(1024)+Leaky ReLU+Dense(1)+Sigmoid,然后对图像进行输出。

 #第三部分
        layer9 = Dense(1024)(layer8)
        layer10 = LeakyReLU(alpha=0.2)(layer9)
        res = Dense(1,activation='sigmoid')(layer10)
        return Model(input, res)

3.4 VGGNet 

利用VGG19的第9层特征层,虽然代码是from keras.applications import VGG19导入的VGG19,但是预训练权重最好还是先去官网下载好,否则第一次下载的太慢了。

贴一个百度云:https://pan.baidu.com/s/1cJabHCqKNdIqNOxHZnXvmg,提取码:vltq

    #采用VGG19的第九层特征层
    def VGG(self):
        vgg = VGG19(weights = "imagenet")
        vgg.outputs = [vgg.layers[9].output]
        img = Input(shape = self.HR_shape)
        img_features = vgg(img)
        return Model(img, img_features)

至此,论文模型就算复现完成,后续将完成模型训练,预测。

4.训练过程

类初始化:

高清图像的大小为512*512*3,低分辨率图片为128*128*3,四倍关系。

 def __init__(self):
        #低分辨率初始化
        self.LR_height = 128
        self.LR_width = 128
        self.channels = 3
        self.LR_shape = (self.LR_height, self.LR_width, self.channels)

        #高分辨率初始化
        self.HR_height = self.LR_height * 4
        self.HR_width = self.LR_width * 4
        self.HR_shape = (self.HR_height, self.HR_width, self.channels)

        #b个残差块
        self.b_residual_blocks = 16

        #优化器,初始学习率0.002
        optimizer = Adam(0.0002, 0.9)
        # 数据集
        self.datasets_name = "DIV"
        self.dataProcess = dataProcess(data_name=self.datasets_name, img=(self.HR_height, self.HR_width))

        self.vgg = self.VGG()
        self.vgg.trainable = False

        #建立生成网络
        self.generator = self.generatorNet()
        self.generator.summary()

        # 建立鉴别网络
        patch = int(self.HR_height / 2**4)
        self.discriminator_patch = (patch, patch, 1)
        self.discriminator = self.discriminatorNet()
        self.discriminator.compile(loss='binary_crossentropy', optimizer=optimizer,metrics=['accuracy'])
        self.discriminator.summary()

        LR_img = Input(shape=self.LR_shape)
        gen_HR = self.generator(LR_img)
        gen_HR_features = self.vgg(gen_HR)

        self.discriminator.trainable = False
        validity = self.discriminator(gen_HR)
        self.combined = Model(LR_img, [validity, gen_HR_features])
        self.combined.compile(loss=['binary_crossentropy', 'mse'], loss_weights=[5e-1,1],optimizer=optimizer)

先对数据集进行预处理:

class dataProcess():
    def __init__(self, data_name, img=(128, 128)):
        self.data_name = data_name
        self.img = img

    def process(self, batch_size=1, data_type=False):
        dataPath = glob('./datasets/%s/train/*' % (self.data_name))
        images = np.random.choice(dataPath, size=batch_size)

        img_HR = []
        img_LR = []
        for i in images:
            img = self.imread(i)
            height, width = self.img
            # 缩小4倍
            L_height, L_width = int(height / 4), int(width / 4)

            img_H = scipy.misc.imresize(img, self.img)
            img_L = scipy.misc.imresize(img, (L_height, L_width))

            if not data_type and np.random.random() < 0.5:
                img_H = np.fliplr(img_H)
                img_L = np.fliplr(img_L)

            img_HR.append(img_H)
            img_LR.append(img_L)

        img_HR = np.array(img_HR) / 127.5 - 1
        img_LR = np.array(img_LR) / 127.5 - 1

        return img_HR, img_LR

    def imread(self, path):
        return scipy.misc.imread(path, mode='RGB').astype(np.float)

训练函数:

先将图片传入产生网络生成伪高清图片,然后分别计算伪高清图像和原始高清图像的 loss值,求他们的平均值。然后将两个图像传入VGG网络,利用VGG第9层网络提取图片的特征向量,并计算loss值。

 def train(self, epochs, init_epoch=0, batch_size=1, example=50):
        begin_time = datetime.datetime.now()
        if init_epoch != 0:
            self.generator.load_weights("weights/%s/gen_epoch%d.h5" % (self.datasets_name, init_epoch), skip_mismatch=True)
            self.discriminator.load_weights("weights/%s/dis_epoch%d.h5" % (self.dataset_name, init_epoch), skip_mismatch=True)

        for epoch in range(init_epoch, epochs):
            self.learining([self.combined, self.discriminator], epoch)
            HR_img, LR_img = self.dataProcess.process(batch_size)
            gen_HR = self.generator.predict(LR_img)
            valid = np.ones((batch_size,) + self.discriminator_patch)
            gen = np.zeros((batch_size,) + self.discriminator_patch)
            origin_loss = self.discriminator.train_on_batch(HR_img, valid)
            gen_loss = self.discriminator.train_on_batch(gen_HR, gen)
            d_loss = 0.5 * np.add(origin_loss, gen_loss)

            HR_img, LR_img = self.dataProcess.process(batch_size)
            valid = np.ones((batch_size,) + self.discriminator_patch)
            img_features = self.vgg.predict(HR_img)
            g_loss = self.combined.train_on_batch(LR_img, [valid, img_features])
            print(d_loss, g_loss)
            end_time = datetime.datetime.now()
            time = end_time - begin_time
            print("[Epoch %d/%d] [D loss: %f, acc: %3d%%] [G loss: %05f, feature loss: %05f] time: %s " \
                  % (epoch,
                     epochs,
                     d_loss[0],
                     100 * d_loss[1],
                     g_loss[1],
                     g_loss[2],
                     time))

            if epoch % example == 0:
                self.restore(epoch)
                if epoch % 500 == 0 and epoch != init_epoch:
                    # 500代保存一次
                    os.makedirs('weights/%s' % self.datasets_name, exist_ok=True)
                    self.generator.save_weights("weights/%s/gen_epoch%d.h5" % (self.datasets_name, epoch))
                    self.dicriminator.save_weights("weights/%s/dis_epoch%d.h5" % (self.datasets_name, epoch))

5. 效果展示

低分辨率图为128*128,伪高清图为512*512

低分辨率图:

生成的伪高清图:

 

 

项目地址:vanbou/TensorFlow_SRGAN (github.com)

 

  • 3
    点赞
  • 42
    收藏
    觉得还不错? 一键收藏
  • 16
    评论
图像超分辨率重建是指将低分辨率图像通过算法处理,得到高分辨率图像的过程。以下是一个基于Python的图像超分辨率重建的简单实现: 首先,我们需要导入一些必要的库: ```python import numpy as np import cv2 from skimage.measure import compare_psnr ``` 然后,我们读取一张低分辨率的图像,并将其展示出来: ```python img_lr = cv2.imread('low_resolution_image.jpg') cv2.imshow('Low Resolution Image', img_lr) cv2.waitKey(0) cv2.destroyAllWindows() ``` 接着,我们使用双三次插值的方式将低分辨率图像放大到目标分辨率,并展示出来: ```python img_bicubic = cv2.resize(img_lr, None, fx=3, fy=3, interpolation=cv2.INTER_CUBIC) cv2.imshow('Bicubic Interpolation Image', img_bicubic) cv2.waitKey(0) cv2.destroyAllWindows() ``` 接下来,我们使用OpenCV中的超分辨率算法实现图像的超分辨率重建: ```python # 创建超分辨率算法对象 sr = cv2.dnn_superres.DnnSuperResImpl_create() # 选择算法模型 sr.readModel('EDSR_x3.pb') sr.setModel('edsr', 3) # 对低分辨率图像进行超分辨率重建 img_sr = sr.upsample(img_lr) # 展示结果 cv2.imshow('Super Resolution Image', img_sr) cv2.waitKey(0) cv2.destroyAllWindows() ``` 最后,我们计算超分辨率重建图像与原始高分辨率图像之间的PSNR值,并输出结果: ```python img_hr = cv2.imread('high_resolution_image.jpg') psnr = compare_psnr(img_hr, img_sr) print('PSNR:', psnr) ``` 这是一个简单的图像超分辨率重建的Python实现。当然,实现一个高质量的图像超分辨率重建算法需要更加深入的研究和实践。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 16
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值