原文代码:GitHub - fuyongXu/SRCNN_Pytorch_1.0: The implemention of SRCNN by pytorch1.0
超分辨率卷积神经网络(SRCNN)是一种经典的用于图像超分辨率重建的深度学习模型。本文使用PyTorch框架进行复现。
一、下载数据集
在论文中的链接Learning a Deep Convolutional Network for Image Super-Resolution (cuhk.edu.hk)获取数据集,下载91image和set5等数据集,在放置代码的文件包中创建Test和Train两个文件夹,将下载的数据集放入对应文件中,91image作为Train集,set5或set14作为Test集。
二、将数据集转换为h5格式
在终端使用以下命令调用prepare.py,将数据集转换为h5格式。
python prepare.py --images-dir /content/drive/MyDrive/SRCNN_Pytorch_1.0-master/data/Train --output-path /content/drive/MyDrive/SRCNN_Pytorch_1.0-master/data/train.h5
python prepare.py --images-dir /content/drive/MyDrive/SRCNN_Pytorch_1.0-master/data/Test --output-path /content/drive/MyDrive/SRCNN_Pytorch_1.0-master/data/test.h5 --eval
三、模型训练
在终端使用以下命令调用train.py。经过400轮训练之后,PSNR最高的一轮参数权重会被保存在best.pth中。
python train.py --train-file /content/drive/MyDrive/SRCNN_Pytorch_1.0-master/data/train.h5 --eval-file /content/drive/MyDrive/SRCNN_Pytorch_1.0-master/data/test.h5 --outputs-dir /content/drive/MyDrive/SRCNN_Pytorch_1.0-master/outputs
在本次训练中,代码添加了可视化,以下是添加了可视化的train.py代码
import argparse
import os
import copy
import torch
from torch import nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from model import SRCNN
from datasets import TrainDataset, EvalDataset
from utils import AverageMeter, calc_psnr
if __name__ == '__main__':
    # SRCNN training script: trains on an h5 training set, evaluates PSNR on
    # an h5 eval set every epoch, logs both to TensorBoard, and keeps the
    # weights of the best-PSNR epoch in best.pth.
    # NOTE(review): the pasted original had lost all indentation; structure
    # restored here without changing any statement.
    parser = argparse.ArgumentParser()
    parser.add_argument('--train-file', type=str, required=True)
    parser.add_argument('--eval-file', type=str, required=True)
    parser.add_argument('--outputs-dir', type=str, required=True)
    parser.add_argument('--scale', type=int, default=3)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--num-workers', type=int, default=0)
    parser.add_argument('--num-epochs', type=int, default=400)
    parser.add_argument('--seed', type=int, default=123)
    args = parser.parse_args()

    # Checkpoints for each scale factor go into their own sub-directory.
    args.outputs_dir = os.path.join(args.outputs_dir, 'x{}'.format(args.scale))
    if not os.path.exists(args.outputs_dir):
        os.makedirs(args.outputs_dir)

    cudnn.benchmark = True
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    torch.manual_seed(args.seed)

    model = SRCNN().to(device)
    criterion = nn.MSELoss()
    # As in the SRCNN paper, the last (reconstruction) layer trains with a
    # 10x smaller learning rate than the first two layers.
    optimizer = optim.Adam([
        {'params': model.conv1.parameters()},
        {'params': model.conv2.parameters()},
        {'params': model.conv3.parameters(), 'lr': args.lr * 0.1}
    ], lr=args.lr)

    train_dataset = TrainDataset(args.train_file)
    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True,
                                  drop_last=True)
    eval_dataset = EvalDataset(args.eval_file)
    eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1)

    best_weights = copy.deepcopy(model.state_dict())
    best_epoch = 0
    best_psnr = 0.0
    writer = SummaryWriter("logs")

    for epoch in range(args.num_epochs):
        # ---- training ----
        model.train()
        epoch_losses = AverageMeter()
        # drop_last=True above, so the progress bar only counts full batches.
        with tqdm(total=(len(train_dataset) - len(train_dataset) % args.batch_size)) as t:
            t.set_description('epoch:{}/{}'.format(epoch, args.num_epochs - 1))
            for data in train_dataloader:
                inputs, labels = data
                inputs = inputs.to(device)
                labels = labels.to(device)
                preds = model(inputs)
                loss = criterion(preds, labels)
                epoch_losses.update(loss.item(), len(inputs))
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                t.set_postfix(loss='{:.6f}'.format(epoch_losses.avg))
                t.update(len(inputs))
        # Save a per-epoch checkpoint in addition to the running best.
        torch.save(model.state_dict(), os.path.join(args.outputs_dir, 'epoch_{}.pth'.format(epoch)))

        # ---- evaluation: average PSNR, predictions clamped to [0, 1] ----
        model.eval()
        epoch_psnr = AverageMeter()
        for data in eval_dataloader:
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            with torch.no_grad():
                preds = model(inputs).clamp(0.0, 1.0)
            epoch_psnr.update(calc_psnr(preds, labels), len(inputs))
        print('eval psnr: {:.2f}'.format(epoch_psnr.avg))

        # Log training loss and eval PSNR to TensorBoard.
        writer.add_scalar("Loss/train", epoch_losses.avg, epoch)
        writer.add_scalar("PSNR/eval", epoch_psnr.avg, epoch)

        if epoch_psnr.avg > best_psnr:
            best_epoch = epoch
            best_psnr = epoch_psnr.avg
            best_weights = copy.deepcopy(model.state_dict())

    writer.close()
    print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))
    torch.save(best_weights, os.path.join(args.outputs_dir, 'best.pth'))
在Colab下打开TensorBoard的可视化命令如下:
%reload_ext tensorboard
%tensorboard --logdir /content/drive/MyDrive/SRCNN_Pytorch_1.0-master/logs
四、测试图片
用训练得到的best.pth作为参数权重,命令如下:
python test.py --weights-file /content/drive/MyDrive/SRCNN_Pytorch_1.0-master/outputs/x3/best.pth --image-file /content/drive/MyDrive/SRCNN_Pytorch_1.0-master/data/Test/baby_GT.bmp --scale 3
在原来的代码中,保存超分辨率结果图的路径必须事先存在,否则会报错,因此我对代码进行了修改,代码如下:
if __name__ == '__main__':
    # SRCNN inference script: loads trained weights, bicubic-degrades the
    # input image, super-resolves its Y channel, and saves both the bicubic
    # baseline and the SRCNN result next to the input image.
    # NOTE(review): the pasted original had lost all indentation; structure
    # restored. Two defects fixed: 'cuda: 0' (space makes the device string
    # invalid and raises RuntimeError on CUDA machines) -> 'cuda:0', and the
    # '__srcnn' double underscore made the output name inconsistent with
    # '_bicubic' -> '_srcnn'.
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights-file', type=str, required=True)
    parser.add_argument('--image-file', type=str, required=True)
    parser.add_argument('--scale', type=int, default=3)
    args = parser.parse_args()

    cudnn.benchmark = True
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    model = SRCNN().to(device)
    # Copy weights parameter-by-parameter so a checkpoint with unexpected
    # keys fails loudly instead of being silently ignored.
    state_dict = model.state_dict()
    for n, p in torch.load(args.weights_file, map_location=lambda storage, loc: storage).items():
        if n in state_dict.keys():
            state_dict[n].copy_(p)
        else:
            raise KeyError(n)
    model.eval()

    image = pil_image.open(args.image_file).convert('RGB')

    # Crop to a multiple of the scale factor, then downsample and re-upsample
    # with bicubic interpolation to create the low-resolution network input.
    image_width = (image.width // args.scale) * args.scale
    image_height = (image.height // args.scale) * args.scale
    image = image.resize((image_width, image_height), resample=pil_image.BICUBIC)
    image = image.resize((image.width // args.scale, image.height // args.scale), resample=pil_image.BICUBIC)
    image = image.resize((image.width * args.scale, image.height * args.scale), resample=pil_image.BICUBIC)

    # Build output paths from the input path so the save directory is
    # guaranteed to exist (this was the reported failure in the original).
    dirname = os.path.dirname(args.image_file)
    basename, extension = os.path.splitext(os.path.basename(args.image_file))
    image.save(os.path.join(dirname, basename + '_bicubic_x{}.bmp'.format(args.scale)))

    # SRCNN operates on the luminance (Y) channel only, normalized to [0, 1].
    image = np.array(image).astype(np.float32)
    ycbcr = convert_rgb_to_ycbcr(image)
    y = ycbcr[..., 0]
    y /= 255.
    y = torch.from_numpy(y).to(device)
    y = y.unsqueeze(0).unsqueeze(0)

    with torch.no_grad():
        preds = model(y).clamp(0.0, 1.0)

    # NOTE(review): this compares the SRCNN output against the bicubic input,
    # not a ground-truth HR image — kept as in the original; confirm intent.
    psnr = calc_psnr(y, preds)
    print('PSNR: {:.2f}'.format(psnr))

    # Merge the predicted Y channel back with the original Cb/Cr channels
    # and convert to an 8-bit RGB image.
    preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)
    output = np.array([preds, ycbcr[..., 1], ycbcr[..., 2]]).transpose([1, 2, 0])
    output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8)
    output = pil_image.fromarray(output)
    output.save(os.path.join(dirname, basename + '_srcnn_x{}.bmp'.format(args.scale)))
五、实验结果
本次训练轮数是230,以下是训练230轮次后,训练集的损失以及测试集的psnr的可视化图。
用set5图片进行测试,得到的psnr结果如下:
在源代码的基础小小修改后的代码见:Mxia-code/SRCNN_pytorch (github.com)