【超分辨率】SRCNN论文复现

原文代码:GitHub - fuyongXu/SRCNN_Pytorch_1.0: The implemention of SRCNN by pytorch1.0


超分辨率卷积神经网络(SRCNN)是一种经典的用于图像超分辨率重建的深度学习模型。本文使用PyTorch框架进行复现。

一、下载数据集

在论文中的链接Learning a Deep Convolutional Network for Image Super-Resolution (cuhk.edu.hk)获取数据集,下载91image和set5等数据集,在放置代码的文件包中创建Test和Train两个文件夹,将下载的数据集放入对应文件中,91image作为Train集,set5或set14作为Test集。

 

二、将数据集转换为h5格式

在终端命令行使用以下命令行调用prepare.py,将数据集转换为h5格式。 

python prepare.py --images-dir /content/drive/MyDrive/SRCNN_Pytorch_1.0-master/data/Train --output-path /content/drive/MyDrive/SRCNN_Pytorch_1.0-master/data/train.h5
python prepare.py --images-dir /content/drive/MyDrive/SRCNN_Pytorch_1.0-master/data/Test --output-path /content/drive/MyDrive/SRCNN_Pytorch_1.0-master/data/test.h5 --eval

三、模型训练

在终端命令行使用以下命令行调用Train.py,经过400轮训练之后,会得到一个好的参数权重,psnr最高,保存在best.pth中。

python train.py --train-file /content/drive/MyDrive/SRCNN_Pytorch_1.0-master/data/train.h5 --eval-file /content/drive/MyDrive/SRCNN_Pytorch_1.0-master/data/test.h5 --outputs-dir /content/drive/MyDrive/SRCNN_Pytorch_1.0-master/outputs

 在本次训练中,代码添加了可视化,以下是添加了可视化的train.py代码

import argparse
import os
import copy

import torch
from torch import nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

from model import SRCNN
from datasets import TrainDataset, EvalDataset
from utils import AverageMeter, calc_psnr

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--train-file', type=str, required=True)
    parser.add_argument('--eval-file', type=str, required=True)
    parser.add_argument('--outputs-dir', type=str, required=True)
    parser.add_argument('--scale', type=int, default=3)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--num-workers', type=int, default=0)
    parser.add_argument('--num-epochs', type=int, default=400)
    parser.add_argument('--seed', type=int, default=123)
    args = parser.parse_args()

    args.outputs_dir = os.path.join(args.outputs_dir, 'x{}'.format(args.scale))

    if not os.path.exists(args.outputs_dir):
        os.makedirs(args.outputs_dir)

    cudnn.benchmark = True
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    torch.manual_seed(args.seed)

    model = SRCNN().to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam([
        {'params': model.conv1.parameters()},
        {'params': model.conv2.parameters()},
        {'params': model.conv3.parameters(), 'lr': args.lr*0.1}
    ], lr=args.lr)

    train_dataset = TrainDataset(args.train_file)
    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True,
                                  drop_last=True)

    eval_dataset = EvalDataset(args.eval_file)
    eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1)

    best_weights = copy.deepcopy(model.state_dict())
    best_epoch = 0
    best_psnr = 0.0

    writer=SummaryWriter("logs")
    for epoch in range(args.num_epochs):
        model.train()
        epoch_losses = AverageMeter()

        with tqdm(total=(len(train_dataset) - len(train_dataset) % args.batch_size)) as t:
            t.set_description('epoch:{}/{}'.format(epoch, args.num_epochs - 1))

            for data in train_dataloader:
                inputs, labels = data

                inputs = inputs.to(device)
                labels = labels.to(device)

                preds = model(inputs)
                loss = criterion(preds, labels)

                epoch_losses.update(loss.item(), len(inputs))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                t.set_postfix(loss='{:.6f}'.format(epoch_losses.avg))
                t.update(len(inputs))

        torch.save(model.state_dict(), os.path.join(args.outputs_dir, 'epoch_{}.pth'.format(epoch)))

        model.eval()
        epoch_psnr = AverageMeter()

        for data in eval_dataloader:
            inputs, labels = data

            inputs = inputs.to(device)
            labels = labels.to(device)

            with torch.no_grad():
                preds = model(inputs).clamp(0.0, 1.0)

            epoch_psnr.update(calc_psnr(preds, labels), len(inputs))

        print('eval psnr: {:.2f}'.format(epoch_psnr.avg))
        # 记录训练损失和评估准确率至 TensorBoard
        writer.add_scalar("Loss/train", epoch_losses.avg, epoch)
        writer.add_scalar("PSNR/eval", epoch_psnr.avg, epoch)

        if epoch_psnr.avg > best_psnr:
            best_epoch = epoch
            best_psnr = epoch_psnr.avg
            best_weights = copy.deepcopy(model.state_dict())
    
    writer.close()

    print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))
    torch.save(best_weights, os.path.join(args.outputs_dir, 'best.pth'))

在Colab下打开TensorBoard的可视化命令如下:

%reload_ext tensorboard
%tensorboard --logdir /content/drive/MyDrive/SRCNN_Pytorch_1.0-master/logs

四、测试图片

用训练得到的best.pth作为参数权重,命令如下:

python test.py --weights-file /content/drive/MyDrive/SRCNN_Pytorch_1.0-master/outputs/x3/best.pth --image-file /content/drive/MyDrive/SRCNN_Pytorch_1.0-master/data/Test/baby_GT.bmp --scale 3

在原来的代码中原图生成超分辨率图保存的路径需要存在,如果不存在则报错,因此我对代码进行了修改,代码如下:

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights-file', type=str, required=True)
    parser.add_argument('--image-file', type=str, required=True)
    parser.add_argument('--scale', type=int, default=3)
    args = parser.parse_args()

    cudnn.benchmark = True
    device = torch.device('cuda: 0' if torch.cuda.is_available() else 'cpu')

    model = SRCNN().to(device)

    state_dict = model.state_dict()
    for n, p in torch.load(args.weights_file, map_location=lambda storage, loc: storage).items():
        if n in state_dict.keys():
            state_dict[n].copy_(p)
        else:
            raise KeyError(n)

    model.eval()

    image = pil_image.open(args.image_file).convert('RGB')

    image_width = (image.width // args.scale) * args.scale
    image_height = (image.height // args.scale) * args.scale
    image = image.resize((image_width, image_height), resample=pil_image.BICUBIC)
    image = image.resize((image.width // args.scale, image.height // args.scale), resample=pil_image.BICUBIC)
    image = image.resize((image.width * args.scale, image.height * args.scale), resample=pil_image.BICUBIC)
    #image.save(args.image_file.replace('.', '_bicubic_x{}.'.format(args.scale)))
    # 生成保存路径
    dirname = os.path.dirname(args.image_file)
    filename = os.path.basename(args.image_file)
    basename, extension = os.path.splitext(filename)
    save_path = os.path.join(dirname, basename + '_bicubic_x{}.bmp'.format(args.scale))
    image.save(save_path)

    image = np.array(image).astype(np.float32)
    ycbcr = convert_rgb_to_ycbcr(image)

    y = ycbcr[..., 0]
    y /= 255.
    y = torch.from_numpy(y).to(device)
    y = y.unsqueeze(0).unsqueeze(0)

    with torch.no_grad():
        preds = model(y).clamp(0.0, 1.0)

    psnr = calc_psnr(y, preds)
    print('PSNR: {:.2f}'.format(psnr))

    preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)

    output = np.array([preds, ycbcr[..., 1], ycbcr[..., 2]]).transpose([1, 2, 0])
    output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8)
    output = pil_image.fromarray(output)
    #output.save(args.image_file.replace('.', '_srcnn_x{}.'.format(args.scale)))
    # 生成保存路径
    dirname = os.path.dirname(args.image_file)
    filename = os.path.basename(args.image_file)
    basename, extension = os.path.splitext(filename)
    save_path = os.path.join(dirname, basename + '__srcnn_x{}.bmp'.format(args.scale))
    output.save(save_path)

五、实验结果

本次训练轮数是230,以下是训练230轮次后,训练集的损失以及测试集的psnr的可视化图。

 用set5图片进行测试,得到的psnr结果如下:


在源代码的基础小小修改后的代码见:Mxia-code/SRCNN_pytorch (github.com)

  • 6
    点赞
  • 43
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
图像超分辨率重建是指将低分辨率图像通过算法处理,得到高分辨率图像的过程。以下是一个基于Python的图像超分辨率重建的简单实现: 首先,我们需要导入一些必要的库: ```python import numpy as np import cv2 from skimage.measure import compare_psnr ``` 然后,我们读取一张低分辨率的图像,并将其展示出来: ```python img_lr = cv2.imread('low_resolution_image.jpg') cv2.imshow('Low Resolution Image', img_lr) cv2.waitKey(0) cv2.destroyAllWindows() ``` 接着,我们使用双三次插值的方式将低分辨率图像放大到目标分辨率,并展示出来: ```python img_bicubic = cv2.resize(img_lr, None, fx=3, fy=3, interpolation=cv2.INTER_CUBIC) cv2.imshow('Bicubic Interpolation Image', img_bicubic) cv2.waitKey(0) cv2.destroyAllWindows() ``` 接下来,我们使用OpenCV中的超分辨率算法实现图像的超分辨率重建: ```python # 创建超分辨率算法对象 sr = cv2.dnn_superres.DnnSuperResImpl_create() # 选择算法模型 sr.readModel('EDSR_x3.pb') sr.setModel('edsr', 3) # 对低分辨率图像进行超分辨率重建 img_sr = sr.upsample(img_lr) # 展示结果 cv2.imshow('Super Resolution Image', img_sr) cv2.waitKey(0) cv2.destroyAllWindows() ``` 最后,我们计算超分辨率重建图像与原始高分辨率图像之间的PSNR值,并输出结果: ```python img_hr = cv2.imread('high_resolution_image.jpg') psnr = compare_psnr(img_hr, img_sr) print('PSNR:', psnr) ``` 这是一个简单的图像超分辨率重建Python实现。当然,实现一个高质量的图像超分辨率重建算法需要更加深入的研究和实践。
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值