整体框架
SR,即super resolution,即超分辨率。CNN相对来说比较著名,就是卷积神经网络了。从名字可以看出,SRCNN是首个应用于超分辨领域的卷积神经网络,事实上也的确如此。
所谓超分辨率,就是把低分辨率(LR, Low Resolution)图片放大为高分辨率(HR, High Resolution)的过程。由于是开山之作,SRCNN相对比较简单,总共分三步
- 输入LR图像 X X X,经双三次(bicubic)插值,被放大成目标尺寸,得到 Y Y Y
- 通过三层卷积网络拟合非线性映射
- 输出HR图像结果 F ( Y ) F(Y) F(Y)
训练的目标损失是最小化SR图像 F ( Y ; θ ) F(Y;\theta) F(Y;θ)和原高分辨率图像 X X X像素差的均方误差
L ( θ ) = 1 n ∑ i = 1 n ∥ F ( Y i ; θ ) − X i ∥ 2 L(\theta)=\frac{1}{n}\sum^n_{i=1}\Vert F(Y_i;\theta)-X_i\Vert^2 L(θ)=n1i=1∑n∥F(Yi;θ)−Xi∥2
其中, n n n为训练样本数量,参数更新公式为
Δ i + 1 = 0.9 Δ i + η ∂ L ∂ W i l , W i + 1 l = W i l + Δ i + 1 \Delta_{i+1}=0.9\Delta_i+\eta\frac{\partial L}{\partial W^l_i},\quad W^l_{i+1}=W^l_i+\Delta_{i+1} Δi+1=0.9Δi+η∂Wil∂L,Wi+1l=Wil+Δi+1
网络模型
其网络结构如下
诚如前文所述,网络分为三个卷积层
- 维度是 1 × 9 × 9 × 64 1\times9\times9\times64 1×9×9×64,表示输入图像通道数为1,进行卷积运算的核尺寸为 9 × 9 9\times9 9×9,输出深度为64。
- 维度是 64 × 5 × 5 × 32 64\times5\times5\times32 64×5×5×32,64即上一层输出,32为下一层输出。
- 维度是 32 × 5 × 5 × 1 32\times5\times5\times1 32×5×5×1。其输出为单通道图像,与输入相同。
所以这个模型实现起来毫无难度
# models.py
class SRCNN(nn.Module):
def __init__(self, nChannel=1):
super(SRCNN,self).__init__()
self.conv1 = nn.Conv2d(nChannel, 64,
kernel_size=9, padding=9//2)
self.conv2 = nn.Conv2d(64, 32,
kernel_size=5, padding=5//2)
self.conv3 = nn.Conv2d(32, nChannel,
kernel_size=5, padding=5//2)
self.relu = nn.ReLU(inplace=True)
def forward(self,x):
x = self.relu(self.conv1(x))
x = self.relu(self.conv2(x))
x = self.conv3(x)
return x
数据集
训练数据集可手动生成,设放大倍数为scale
,考虑到原始数据未必会被scale
整除,所以要重新规划一下图像尺寸,所以训练数据集的生成分为三步:
- 将原始图像通过双三次插值重设尺寸,使之可被
scale
整除,作为高分辨图像数据HR - 将HR通过双三次插值压缩scale倍,为低分辨图像的原始数据
- 将低分辨图像通过双三次插值放大scale倍,与HR图像维度相等,作为低分辨图像数据LR
最后,可通过h5py
将训练数据分块并打包,其生成代码为
import h5py
import PIL.Image as pImg
def rgb2gray(img):
return 16. + (64.738 * img[:, :, 0] + 129.057 * img[:, :, 1] + 25.064 * img[:, :, 2]) / 256.
# imgPath为图像路径;h5Path为存储路径;scale为放大倍数
# pSize为patch尺寸; pStride为步长
def setTrianData(imgPath, h5Path, scale=3, pSize=33, pStride=14):
h5_file = h5py.File(h5Path, 'w')
lrPatches, hrPatches = [], [] #用于存储低分辨率和高分辨率的patch
for p in sorted(glob.glob(f'{imgPath}/*')):
hr = pImg.open(p).convert('RGB')
lrWidth, lrHeight = hr.width // scale, hr.height // scale
# width, height为可被scale整除的训练数据尺寸
width, height = lrWidth*scale, lrHeight*scale
hr = hr.resize((width, height), resample=pImg.BICUBIC)
lr = hr.resize((lrWidth, lrHeight), resample=pImg.BICUBIC)
lr = lr.resize((width, height), resample=pImg.BICUBIC)
hr = np.array(hr).astype(np.float32)
lr = np.array(lr).astype(np.float32)
hr = rgb2gray(hr)
lr = rgb2gray(lr)
# 将数据分割
for i in range(0, height - pSize + 1, pStride):
for j in range(0, width - pSize + 1, pStride):
lrPatches.append(lr[i:i + pSize, j:j + pSize])
hrPatches.append(hr[i:i + pSize, j:j + pSize])
h5_file.create_dataset('lr', data=np.array(lrPatches))
h5_file.create_dataset('hr', data=np.array(hrPatches))
h5_file.close()
以比较常见的T91数据集为例,通过上面的方法,可以得到一个181M的h5文件。
对于预测数据,也做同样处理。
在做好训练数据之后,需要为这些数据创建一个读取类,以便torch
中的DataLoader
调用,而DataLoader
中的内容则是Dataset
,所以新建的读取类需要继承Dataset
,并实现其__getitem__
和__len__
这两个成员方法。
这两个方法只是看上去吓人,但对Python稍有一点深入了解,就会知道__getitem__
是字典索引的方法,而__len__
则设定了len
函数的返回值。
import h5py
import numpy as np
from torch.utils.data import Dataset
class DataSet(Dataset):
def __init__(self, h5_file):
super(Dataset, self).__init__()
self.h5_file = h5_file
def __getitem__(self, idx):
with h5py.File(self.h5_file, 'r') as f:
return np.expand_dims(f['lr'][idx] / 255., 0), np.expand_dims(f['hr'][idx] / 255., 0)
def __len__(self):
with h5py.File(self.h5_file, 'r') as f:
return len(f['lr'])
训练
首先,训练需要一点准备工作,比如数据集准备好,相关的文件夹需要建好,建好模型之后,需要采用什么样的优化方式。训练设备是用cpu
还是cuda
,然后将数据集和模型装载到设备上。
数据准备
import os
import copy
import torch
from torch import nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data.dataloader import DataLoader
from models import SRCNN
trainFile = "91-image.h5"
evalFile = "Set5.h5"
cudnn.benchmark = True
# 设置训练设备 是CPU还是cuda
device = torch.device(
'cuda:0' if torch.cuda.is_available() else 'cpu')
# 装载训练数据
trainData = Dataset(trainFile)
trainLoader = DataLoader(dataset=trainData,
bSize=bSize,
shuffle=True, # 表示打乱样本
num_workers=nWorker, # 线程数
pin_memory=True, # 方便载入CUDA
drop_last=True)
# 装载预测数据
evalDatas = Dataset(evalFile)
evalLoader = DataLoader(dataset=evalDatas, bSize=1)
模型准备
# 模型和设备
lr = 1e-4 #学习率
torch.manual_seed(seed) #设置随机数种子
model = SRCNN().to(device) #将模型载入设备
criterion = nn.MSELoss() #设置损失函数
optimizer = optim.Adam([
{'params': model.conv1.parameters()},
{'params': model.conv2.parameters()},
{'params': model.conv3.parameters(), 'lr': lr * 0.1}
], lr=lr)
训练
outPath = "outputs"
scale = 3
bSize = 16
nEpoch = 400
nWorker = 8 #线程数
seed = 42 #随机数种子
def initPSNR():
return {'avg':0, 'sum':0, 'count':0}
def updatePSNR(psnr, val, n=1):
s = psnr['sum'] + val*n
c = psnr['count'] + n
return {'avg':s/c, 'sum':s, 'count':c}
bestWeights = copy.deepcopy(model.state_dict()) #最佳模型
bestEpoch = 0 #最佳训练结果
bestPSNR = 0.0 #最佳psnr
# 训练主循环
for epoch in range(nEpoch):
model.train()
epochLosses = initPSNR()
for data in trainLoader:
inputs, labels = data
inputs = inputs.to(device)
labels = labels.to(device)
preds = model(inputs)
loss = criterion(preds, labels)
epochLosses = updatePSNR(epochLosses,loss.item(), len(inputs))
optimizer.zero_grad() #清空梯度
loss.backward() #反向传播
optimizer.step() #根据梯度更新网络参数
print(f'{epochLosses['avg']:.6f}')
torch.save(model.state_dict(),
os.path.join(outPath, f'epoch_{epoch}.pth'))
model.eval() #取消dropout
psnr = AverageMeter()
for data in evalLoader:
inputs, labels = data
inputs = inputs.to(device)
labels = labels.to(device)
# 令reqires_grad自动设为False,关闭自动求导
# clamp将inputs归一化为0到1区间
with torch.no_grad():
preds = model(inputs).clamp(0.0, 1.0)
tmp_psnr = 10. * torch.log10(
1. / torch.mean((preds - labels) ** 2))
psnr = updatePSNR(psnr, tmp_psnr, len(inputs))
print(f'eval psnr: {psnr.avg:.2f}')
if psnr['avg'] > bestPSNR:
bestEpoch = epoch
bestPSNR = psnr['avg']
bestWeights = copy.deepcopy(model.state_dict())
print(f'best epoch: {bestEpoch}, psnr: {bestPSNR:.2f}')
torch.save(bestWeights, os.path.join(outPath, 'best.pth'))
最终的结果为