【超分辨率】SRCNN论文复现

原文代码:GitHub - fuyongXu/SRCNN_Pytorch_1.0: The implemention of SRCNN by pytorch1.0


超分辨率卷积神经网络(SRCNN)是一种经典的用于图像超分辨率重建的深度学习模型。本文使用PyTorch框架进行复现。

一、下载数据集

在论文中的链接Learning a Deep Convolutional Network for Image Super-Resolution (cuhk.edu.hk)获取数据集,下载91image和set5等数据集,在放置代码的文件包中创建Test和Train两个文件夹,将下载的数据集放入对应文件中,91image作为Train集,set5或set14作为Test集。

 

二、将数据集转换为h5格式

在终端使用以下命令调用prepare.py,将数据集转换为h5格式。

python prepare.py --images-dir /content/drive/MyDrive/SRCNN_Pytorch_1.0-master/data/Train --output-path /content/drive/MyDrive/SRCNN_Pytorch_1.0-master/data/train.h5
python prepare.py --images-dir /content/drive/MyDrive/SRCNN_Pytorch_1.0-master/data/Test --output-path /content/drive/MyDrive/SRCNN_Pytorch_1.0-master/data/test.h5 --eval

三、模型训练

在终端使用以下命令调用train.py,默认训练400轮,训练过程中在测试集上PSNR最高的一组参数权重会被保存在best.pth中。

python train.py --train-file /content/drive/MyDrive/SRCNN_Pytorch_1.0-master/data/train.h5 --eval-file /content/drive/MyDrive/SRCNN_Pytorch_1.0-master/data/test.h5 --outputs-dir /content/drive/MyDrive/SRCNN_Pytorch_1.0-master/outputs

 在本次训练中,代码添加了可视化,以下是添加了可视化的train.py代码

import argparse
import os
import copy

import torch
from torch import nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

from model import SRCNN
from datasets import TrainDataset, EvalDataset
from utils import AverageMeter, calc_psnr

if __name__ == '__main__':
    # Train SRCNN on patch pairs stored in HDF5 files, evaluate PSNR after
    # every epoch, log both curves to TensorBoard, and keep the weights of
    # the best-PSNR epoch in best.pth.
    parser = argparse.ArgumentParser()
    parser.add_argument('--train-file', type=str, required=True)
    parser.add_argument('--eval-file', type=str, required=True)
    parser.add_argument('--outputs-dir', type=str, required=True)
    parser.add_argument('--scale', type=int, default=3)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--num-workers', type=int, default=0)
    parser.add_argument('--num-epochs', type=int, default=400)
    parser.add_argument('--seed', type=int, default=123)
    args = parser.parse_args()

    # Group checkpoints per upscale factor, e.g. outputs/x3.
    args.outputs_dir = os.path.join(args.outputs_dir, 'x{}'.format(args.scale))
    os.makedirs(args.outputs_dir, exist_ok=True)

    cudnn.benchmark = True
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    torch.manual_seed(args.seed)

    model = SRCNN().to(device)
    criterion = nn.MSELoss()
    # Following the SRCNN paper, the last (reconstruction) layer is trained
    # with a 10x smaller learning rate than the first two layers.
    optimizer = optim.Adam([
        {'params': model.conv1.parameters()},
        {'params': model.conv2.parameters()},
        {'params': model.conv3.parameters(), 'lr': args.lr*0.1}
    ], lr=args.lr)

    train_dataset = TrainDataset(args.train_file)
    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True,
                                  drop_last=True)

    # Evaluation images have different sizes, so batch_size must stay 1.
    eval_dataset = EvalDataset(args.eval_file)
    eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1)

    best_weights = copy.deepcopy(model.state_dict())
    best_epoch = 0
    best_psnr = 0.0

    writer = SummaryWriter("logs")
    for epoch in range(args.num_epochs):
        model.train()
        epoch_losses = AverageMeter()

        # Progress bar counts samples; drop_last trims the partial batch.
        with tqdm(total=(len(train_dataset) - len(train_dataset) % args.batch_size)) as t:
            t.set_description('epoch:{}/{}'.format(epoch, args.num_epochs - 1))

            for data in train_dataloader:
                inputs, labels = data

                inputs = inputs.to(device)
                labels = labels.to(device)

                preds = model(inputs)
                loss = criterion(preds, labels)

                epoch_losses.update(loss.item(), len(inputs))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                t.set_postfix(loss='{:.6f}'.format(epoch_losses.avg))
                t.update(len(inputs))

        # Keep a per-epoch checkpoint in addition to the best one.
        torch.save(model.state_dict(), os.path.join(args.outputs_dir, 'epoch_{}.pth'.format(epoch)))

        model.eval()
        epoch_psnr = AverageMeter()

        for data in eval_dataloader:
            inputs, labels = data

            inputs = inputs.to(device)
            labels = labels.to(device)

            with torch.no_grad():
                # Clamp to the valid [0, 1] intensity range before PSNR.
                preds = model(inputs).clamp(0.0, 1.0)

            epoch_psnr.update(calc_psnr(preds, labels), len(inputs))

        print('eval psnr: {:.2f}'.format(epoch_psnr.avg))
        # Log training loss and evaluation PSNR to TensorBoard.
        writer.add_scalar("Loss/train", epoch_losses.avg, epoch)
        writer.add_scalar("PSNR/eval", epoch_psnr.avg, epoch)

        # Track the best-PSNR weights seen so far.
        if epoch_psnr.avg > best_psnr:
            best_epoch = epoch
            best_psnr = epoch_psnr.avg
            best_weights = copy.deepcopy(model.state_dict())

    writer.close()

    print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))
    torch.save(best_weights, os.path.join(args.outputs_dir, 'best.pth'))

在Colab下打开TensorBoard的可视化命令如下:

%reload_ext tensorboard
%tensorboard --logdir /content/drive/MyDrive/SRCNN_Pytorch_1.0-master/logs

四、测试图片

用训练得到的best.pth作为参数权重,命令如下:

python test.py --weights-file /content/drive/MyDrive/SRCNN_Pytorch_1.0-master/outputs/x3/best.pth --image-file /content/drive/MyDrive/SRCNN_Pytorch_1.0-master/data/Test/baby_GT.bmp --scale 3

在原来的代码中,超分辨率结果的保存路径是直接对原图路径做字符串替换得到的,当路径格式不符合预期时会报错,因此我对保存路径的生成方式进行了修改,代码如下:

if __name__ == '__main__':
    # Run a trained SRCNN on one image: build the bicubic low-resolution
    # input, super-resolve its Y channel, report PSNR, and save both the
    # bicubic and the SRCNN result next to the input image.
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights-file', type=str, required=True)
    parser.add_argument('--image-file', type=str, required=True)
    parser.add_argument('--scale', type=int, default=3)
    args = parser.parse_args()

    cudnn.benchmark = True
    # BUG FIX: the original string 'cuda: 0' contains a space, which is not a
    # valid device string and makes torch.device raise on a CUDA machine.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    model = SRCNN().to(device)

    # Copy weights parameter-by-parameter so any unexpected key in the
    # checkpoint fails loudly with a KeyError.
    state_dict = model.state_dict()
    for n, p in torch.load(args.weights_file, map_location=lambda storage, loc: storage).items():
        if n in state_dict.keys():
            state_dict[n].copy_(p)
        else:
            raise KeyError(n)

    model.eval()

    image = pil_image.open(args.image_file).convert('RGB')

    # Crop to a multiple of the scale, then downscale and upscale with
    # bicubic interpolation to produce the degraded input SRCNN expects.
    image_width = (image.width // args.scale) * args.scale
    image_height = (image.height // args.scale) * args.scale
    image = image.resize((image_width, image_height), resample=pil_image.BICUBIC)
    image = image.resize((image.width // args.scale, image.height // args.scale), resample=pil_image.BICUBIC)
    image = image.resize((image.width * args.scale, image.height * args.scale), resample=pil_image.BICUBIC)

    # Build save paths from the input file name once, instead of replacing
    # '.' in the whole path (which breaks when the directory contains dots).
    dirname = os.path.dirname(args.image_file)
    basename, _extension = os.path.splitext(os.path.basename(args.image_file))
    image.save(os.path.join(dirname, basename + '_bicubic_x{}.bmp'.format(args.scale)))

    image = np.array(image).astype(np.float32)
    ycbcr = convert_rgb_to_ycbcr(image)

    # SRCNN operates on the luminance (Y) channel only, normalized to [0, 1].
    y = ycbcr[..., 0]
    y /= 255.
    y = torch.from_numpy(y).to(device)
    y = y.unsqueeze(0).unsqueeze(0)

    with torch.no_grad():
        preds = model(y).clamp(0.0, 1.0)

    # NOTE(review): PSNR is measured between the network output and the
    # bicubic input y (not a ground-truth image), as in the original repo.
    psnr = calc_psnr(y, preds)
    print('PSNR: {:.2f}'.format(psnr))

    preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)

    # Recombine the super-resolved Y with the bicubic Cb/Cr channels.
    output = np.array([preds, ycbcr[..., 1], ycbcr[..., 2]]).transpose([1, 2, 0])
    output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8)
    output = pil_image.fromarray(output)
    # TYPO FIX: single underscore ('_srcnn'), matching the '_bicubic' name.
    output.save(os.path.join(dirname, basename + '_srcnn_x{}.bmp'.format(args.scale)))

五、实验结果

本次训练轮数是230,以下是训练230轮次后,训练集的损失以及测试集的psnr的可视化图。

 用set5图片进行测试,得到的psnr结果如下:


在源代码的基础小小修改后的代码见:Mxia-code/SRCNN_pytorch (github.com)

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值