IRN的PyTorch实现+使用预训练模型将LR图像变成HR图像

 

IRN的PyTorch实现+使用预训练模型将LR图像变成HR图像

文章:https://arxiv.org/abs/2005.05650

源码:https://github.com/pkuxmq/Invertible-Image-Rescaling#invertible-image-rescaling

1 文章简述

高分辨率的数字图像通常被缩小以适应不同的显示屏或节省存储和带宽的成本,同时后期放大被用于恢复原始分辨率或是缩放图像中的细节。但是,由于高频信息的丢失,典型的图像下采样是一种非单射映射,这就导致了逆上采样过程的不适定性问题,并给低分辨率图像下采样后的细节恢复带来了很大的挑战。单纯用图像超分辨率方法进行缩放,恢复效果不理想。在这项工作中,我们提出通过从一个新的角度,即可逆的双目标变换,来解决这一问题,它可以很大程度上缓解图像缩放的不适定性。我们开发了一个具有精心设计的框架和目标的 Invertible Rescaling Net (IRN) 模型,以产生视觉上令人满意的低分辨率图像,同时在降低尺度的过程中,利用一个潜在变量在指定的分布下捕获丢失信息的分布。这样,通过网络将随机绘制的潜在变量与低分辨率图像反向传递,使上采样变得可控制。实验结果表明,我们的模型比现有的方法在定量和定性评估将下采样图像放大重建方面都有显著的改进。

我们的贡献总结如下:

  • 据我们所知,所提出的 IRN 是第一次尝试用一个可逆(即双射)变换来建模图像的下采样和上采样,这是一对相互逆的任务。基于这种可逆性,我们提出的 IRN 可以很大程度上减轻由缩小的 LR 图像进行图像放大重建的病态性质。
  • 我们提出了一种新的模型设计和有效的训练目标,以加强潜在变量 z,在下采样方向上嵌入丢失的高频信息,以服从一个简单的个例不确定分布。这使得基于从特定分布中抽取的有价值的 z 样本的上采样有效。
  • 与目前最先进的 下采样-SR 和 编解码 方法相比,提出的 IRN 可以显著提高从缩小的 LR 图像进行放大重建的性能。同时,IRN的参数量也大大减少,表明了新算法的轻量化和高效性。

 

2 训练

在 options/train/ 中修改 “.yml” 后缀的配置文件,添加训练集和验证集的 HR 图像与 LR 图像的存放路径。如果没有 LR 图像,可以只添加 HR 图像的存放路径,训练时程序会自动生成对应的 LR 图像。

python train.py -opt options/train/train_IRN_x4.yml

 

3 测试

在 options/test/ 中修改 “.yml” 后缀的配置文件,添加测试集的 HR 图像与 LR 图像的存放路径。如果没有 LR 图像,可以只添加 HR 图像的存放路径。下载预训练模型(Google DriveBaidu Drive(code: lukj)) ,并添加路径到 path 中。

python test.py -opt options/test/test_IRN_x4.yml

 

4 LR->HR

1、在 options/test/ 中修改 “.yml” 后缀的配置文件,在 datasets 中添加自己的测试集,设置 LR 图像的存放路径。

python lr_to_hr.py -opt options/test/test_IRN_x4.yml

2、lr_to_hr.py 文件,如果设置了 ground truth(GT) 图像,即 HR 图像的存放路径,可以取消注释部分,计算生成的 SR 图像与 GT 图像的 PSNR。

import os.path as osp
import logging
import time
import argparse
import torch
from collections import OrderedDict

import numpy as np
import options.options as option
import utils.util as util
import data.util as dutil
from data.util import bgr2ycbcr
from data import create_dataset, create_dataloader
from models import create_model


def demo(opt):
    util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root'
                 and 'pretrain_model' not in key and 'resume' not in key))
    util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO, screen=True, tofile=True)
    logger = logging.getLogger('base')
    logger.info(option.dict2str(opt))

    model = create_model(opt)

    #### Create test data
    for phase, dataset_opt in sorted(opt['datasets'].items()):
        test_set_name = dataset_opt['name']
        dataset_dir = osp.join(opt['path']['results_root'], test_set_name)
        util.mkdir(dataset_dir)

        # test_results = OrderedDict()
        # test_results['psnr'] = []
        # test_results['ssim'] = []
        # test_results['psnr_y'] = []
        # test_results['ssim_y'] = []

        paths_LQ, _ = dutil.get_image_paths(opt['data_type'], dataset_opt['dataroot_LQ'])
        # paths_GT, _ = dutil.get_image_paths(opt['data_type'], dataset_opt['dataroot_LQ'].replace('LR', 'HR'))
        for i in range(len(paths_LQ)):
            LQ_path = paths_LQ[i]
            # GT_path = paths_GT[i]
            img_LQ = dutil.read_img(None, LQ_path, None)
            # img_GT = dutil.read_img(None, GT_path, None)
            img_name = osp.splitext(osp.basename(LQ_path))[0]
            img_name = img_name.split('_')[0]

            if opt['color']:
                img_LQ = dutil.channel_convert(img_LQ.shape[2], opt['color'], [img_LQ])[0]
                # img_GT = dutil.channel_convert(img_GT.shape[2], opt['color'], [img_GT])[0]
            if img_LQ.shape[2] == 3:
                img_LQ = img_LQ[:, :, [2, 1, 0]]
                # img_GT = img_GT[:, :, [2, 1, 0]]
            img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float().unsqueeze(0)
            # img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float().unsqueeze(0)

            test_start_time = time.time()
            SR_img = model.upscale(img_LQ, opt['scale'])

            sr_img = util.tensor2img(SR_img.detach()[0].float().cpu())
            lrgt_img = util.tensor2img(img_LQ.detach()[0].float().cpu())
            # gt_img = util.tensor2img(img_GT.detach()[0].float().cpu())

            suffix = opt['suffix']
            if suffix:
                save_img_path = osp.join(dataset_dir, img_name + suffix + '_SR.png')
            else:
                save_img_path = osp.join(dataset_dir, img_name + '_SR.png')
            util.save_img(sr_img, save_img_path)

            if suffix:
                save_img_path = osp.join(dataset_dir, img_name + suffix + '_LR_ref.png')
            else:
                save_img_path = osp.join(dataset_dir, img_name + '_LR_ref.png')
            util.save_img(lrgt_img, save_img_path)

            # if suffix:
            #     save_img_path = osp.join(dataset_dir, img_name + suffix + '_GT_ref.png')
            # else:
            #     save_img_path = osp.join(dataset_dir, img_name + '_GT_ref.png')
            # util.save_img(gt_img, save_img_path)

            test_end_time = time.time()
            logger.info('test images in [{:s}]: [{:d}] - image name: {:s}; test time: {:.6f};'
                        .format(test_set_name, i + 1, img_name, test_end_time - test_start_time))

            # calculate PSNR and SSIM
            # gt_img = gt_img / 255.
            # sr_img = sr_img / 255.

            # crop_border = opt['crop_border'] if opt['crop_border'] else opt['scale']
            # if crop_border == 0:
            #     cropped_sr_img = sr_img
            #     cropped_gt_img = gt_img
            # else:
            #     cropped_sr_img = sr_img[crop_border:-crop_border, crop_border:-crop_border, :]
            #     cropped_gt_img = gt_img[crop_border:-crop_border, crop_border:-crop_border, :]

            # psnr = util.calculate_psnr(cropped_sr_img * 255, cropped_gt_img * 255)
            # ssim = util.calculate_ssim(cropped_sr_img * 255, cropped_gt_img * 255)
            # test_results['psnr'].append(psnr)
            # test_results['ssim'].append(ssim)

            # if gt_img.shape[2] == 3:  # RGB image
            #     sr_img_y = bgr2ycbcr(sr_img, only_y=True)
            #     gt_img_y = bgr2ycbcr(gt_img, only_y=True)
            #     if crop_border == 0:
            #         cropped_sr_img_y = sr_img_y
            #         cropped_gt_img_y = gt_img_y
            #     else:
            #         cropped_sr_img_y = sr_img_y[crop_border:-crop_border, crop_border:-crop_border]
            #         cropped_gt_img_y = gt_img_y[crop_border:-crop_border, crop_border:-crop_border]
            #     psnr_y = util.calculate_psnr(cropped_sr_img_y * 255, cropped_gt_img_y * 255)
            #     ssim_y = util.calculate_ssim(cropped_sr_img_y * 255, cropped_gt_img_y * 255)
            #     test_results['psnr_y'].append(psnr_y)
            #     test_results['ssim_y'].append(ssim_y)
            
            #     logger.info('{:10s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.\n'.
            #                 format(img_name, psnr, ssim, psnr_y, ssim_y))
            # else:
            #     logger.info('{:10s} - PSNR: {:.6f} dB; SSIM: {:.6f}.\n'
            #                 .format(img_name, psnr, ssim))

        # # Average PSNR/SSIM results
        # ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
        # ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
        # logger.info('----Average PSNR/SSIM results for {}----\tPSNR: {:.6f} db; SSIM: {:.6f}.\n'
        #             .format(test_set_name, ave_psnr, ave_ssim))

        # if test_results['psnr_y'] and test_results['ssim_y']:
        #     ave_psnr_y = sum(test_results['psnr_y']) / len(test_results['psnr_y'])
        #     ave_ssim_y = sum(test_results['ssim_y']) / len(test_results['ssim_y'])
            # logger.info('----Y channel, average PSNR/SSIM----\tPSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.\n'
            #             .format(ave_psnr_y, ave_ssim_y))


if __name__ == '__main__':
    #### options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, default='options/test/test_IRN_x4.yml', help='Path to options YMAL file.')
    opt = option.parse(parser.parse_args().opt, is_train=False)
    opt = option.dict_to_nonedict(opt)
    demo(opt)

3、在 models/IRN_model.py 文件中,修改 IRNModel 类中的 upscale 方法:

    def upscale(self, LR_img, scale, gaussian_scale=1):
        LR_img = LR_img.to(self.device)
        Lshape = LR_img.shape
        zshape = [Lshape[0], Lshape[1] * (scale**2 - 1), Lshape[2], Lshape[3]]

        self.netG.eval()
        with torch.no_grad():
            LR_img = self.Quantization(LR_img)
            y_ = torch.cat((LR_img, gaussian_scale * self.gaussian_batch(zshape)), dim=1)
            HR_img = self.netG(x=y_, rev=True)[:, :3, :, :]
        self.netG.train()

        return HR_img

 

  • 4
    点赞
  • 27
    收藏
    觉得还不错? 一键收藏
  • 17
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 17
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值