基于SRResNet的图像超分辨率重建
因为事务繁忙,所以博客好久都没有更新了,今天难得有空更新一下。
1. 任务描述
使用Pytorch实现SRResNet模型。
2. 知识准备
2.1 图像超分辨率
像超分辨率是指从低分辨率图像中恢复出自然、清晰的纹理,最终得到一张高分辨率图像,是图像增强领域中一个非常重要的问题。近年来,得益于深度学习技术强大的学习能力,该问题有了显著的进展1。
2.2 SRResNet
SRResNet 网络来源于SRGAN的生成器,允许修复更高频的细节。SRResNet上存在两个小的更改:一个是 SRResNet 使用 Parametric ReLU 而不是 ReLU,ReLU 引入一个可学习参数帮助它适应性地学习部分负系数;另一个区别是 SRResNet 使用了图像上采样方法,SRResNet 使用了子像素卷积层2。
3. 技术实现
3.1 开发环境
1 pytorch == '1.7.0+cu101'
2 numpy == '1.19.4'
3 PIL == '8.0.1'
[也是比较新的版本了]
3.2 数据处理部分
在这里使用的是Urban100数据集,当然使用其他数据集也没有太大的问题(不建议使用带有灰度图的数据集,会报错)
首先,我们需要构建一个数据集类,这里使用的是Pytorch的方法,即继承Pytorch的Dataset类.
数据集处理需要使用到以下库
import torch
from torch.utils.data import Dataset
import numpy as np
import os
from PIL import Image
随后,我们便可以构建出数据集类,首先我们需要定义好图像的处理操作,在这里使用的是随机裁剪+转为张量操作,我们将图片随机裁剪出(96 x 96)的大小并作为高像素图像,然后再通过最大池化的方式将其下采样为(24 x 24)作为低像素图像。
#图像处理操作,包括随机裁剪,转换张量
transform = transforms.Compose([transforms.RandomCrop(96),
transforms.ToTensor()])
随后,我们便可以通过继承Pytorch的 Dataset 类来构建我们的数据集
class PreprocessDataset(Dataset):
"""预处理数据集类"""
def __init__(self,imgPath = path,transforms = transform, ex = 10):
"""初始化预处理数据集类"""
self.transforms = transform
for _,_,files in os.walk(imgPath):
#ex变量是用于扩充数据集的,在这里默认的是扩充十倍
self.imgs = [imgPath + file for file in files] * ex
np.random.shuffle(self.imgs) #随机打乱
def __len__(self):
"""获取数据长度"""
return len(self.imgs)
def __getitem__(self,index):
"""获取数据"""
tempImg = self.imgs[index]
tempImg = Image.open(tempImg)
sourceImg = self.transforms(tempImg) #对原始图像进行处理
cropImg = torch.nn.MaxPool2d(4)(sourceImg)
return cropImg,sourceImg
注意的是,里面的__len__和__getitem__方法是必须的,这两个方法使得我们后面可以很轻松地进行调用。
3.3 模型搭建
我们使用的是比较简单的SRResNet模型,本文的模型和官方模型稍有不同,在这里使用了反射填充的方式来维持输入图像和输出图形之间的大小关系。
[再放一次图片]
模型搭建所需要额外引入的库为
import torch.nn as nn
import torch.nn.functional as F
首先,我们先构建网路中的残差模块,这里的代码思路借鉴了Pytorch官方ResNet的实现思路
class ResBlock(nn.Module):
"""残差模块"""
def __init__(self,inChannals,outChannals):
"""初始化残差模块"""
super(ResBlock,self).__init__()
self.conv1 = nn.Conv2d(inChannals,outChannals,kernel_size=1,bias=False)
self.bn1 = nn.BatchNorm2d(outChannals)
self.conv2 = nn.Conv2d(outChannals,outChannals,kernel_size=3,stride=1,padding=1,bias=False)
self.bn2 = nn.BatchNorm2d(outChannals)
self.conv3 = nn.Conv2d(outChannals,outChannals,kernel_size=1,bias=False)
self.relu = nn.PReLU()
def forward(self,x):
"""前向传播过程"""
resudial = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(x)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(x)
out += resudial
out = self.relu(out)
return out
[不得不说,比Tensorflow复杂了好多]
随后,我们可以对整个模型进行构建,注意的是我们使用的是反射填充,这样能够有效提高图像边缘的修复质量。
class SRResNet(nn.Module):
"""SRResNet模型(4x)"""
def __init__(self):
"""初始化模型配置"""
super(SRResNet,self).__init__()
#卷积模块1
self.conv1 = nn.Conv2d(3,64,kernel_size=9,padding=4,padding_mode='reflect',stride=1)
self.relu = nn.PReLU()
#残差模块
self.resBlock = self._makeLayer_(ResBlock,64,64,16)
#卷积模块2
self.conv2 = nn.Conv2d(64,64,kernel_size=1,stride=1)
self.bn2 = nn.BatchNorm2d(64)
self.relu2 = nn.PReLU()
#子像素卷积
self.convPos1 = nn.Conv2d(64,256,kernel_size=3,stride=1,padding=2,padding_mode='reflect')
self.pixelShuffler1 = nn.PixelShuffle(2)
self.reluPos1 = nn.PReLU()
self.convPos2 = nn.Conv2d(64,256,kernel_size=3,stride=1,padding=1,padding_mode='reflect')
self.pixelShuffler2 = nn.PixelShuffle(2)
self.reluPos2 = nn.PReLU()
self.finConv = nn.Conv2d(64,3,kernel_size=9,stride=1)
def _makeLayer_(self,block,inChannals,outChannals,blocks):
"""构建残差层"""
layers = []
layers.append(block(inChannals,outChannals))
for i in range(1,blocks):
layers.append(block(outChannals,outChannals))
return nn.Sequential(*layers)
def forward(self,x):
"""前向传播过程"""
x = self.conv1(x)
x = self.relu(x)
residual = x
out = self.resBlock(x)
out = self.conv2(out)
out = self.bn2(out)
out += residual
out = self.convPos1(out)
out = self.pixelShuffler1(out)
out = self.reluPos1(out)
out = self.convPos2(out)
out = self.pixelShuffler2(out)
out = self.reluPos2(out)
out = self.finConv(out)
return out
至此,模型和数据集类都已经搭建完毕,接下来就可以进行正式的训练了。
3.4 训练过程
本文操作系统为Windows 10,使用1块RTX 2070进行加速运算,30Epoch 时间约为12分钟。
在训练过程额外使用到的库为
from torch.utils.data import DataLoader
import torch.optim as optim
from tqdm import tqdm
首先,我们将初始化文件路径以及设备信息,以及初始化网络
path = './Urban100/'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = SRResNet()
net.to(device)
随后,我们构建好数据集,这里使用的batch大小为32.
BATCH = 32
#构建数据集
processDataset = PreprocessDataset()
trainData = DataLoader(processDataset,batch_size=BATCH)
随后,初始化迭代器和损失函数,这里使用到的是Adam迭代器。
损失函数使用的是原作者使用的均方差损失函数。
optimizer = optim.Adam(net.parameters(),lr=0.001) #初始化迭代器
lossF = nn.MSELoss().to(device) #初始化损失函数
至此一切准备都已完成,现在可以进行正式的训练了
EPOCHS = 30
history = []
for epoch in range(EPOCHS):
net.train()
runningLoss = 0.0
for i,(cropImg,sourceImg) in tqdm(enumerate(trainData, 1)):
cropImg,sourceImg = cropImg.to(device),sourceImg.to(device)
#清空梯度流
optimizer.zero_grad()
#进行训练
outputs = net(cropImg)
loss = lossF(outputs,sourceImg)
loss.backward() #反向传播
optimizer.step()
runningLoss += loss.item()
averageLoss = runningLoss/(i+1)
history += [averageLoss]
print('[INFO] Epoch %d loss: %.3f' %(epoch+1,averageLoss))
runningLoss = 0.0
print('[INFO] Finished Training \nWuhu~')
开始训练后,将会有如下结果 [因开发环境不同而异]
32it [00:24, 1.31it/s]
0it [00:00, ?it/s]
[INFO] Epoch 23 loss: 0.017
32it [00:25, 1.28it/s]
0it [00:00, ?it/s]
[INFO] Epoch 24 loss: 0.016
32it [00:24, 1.30it/s]
0it [00:00, ?it/s]
[INFO] Epoch 25 loss: 0.016
32it [00:24, 1.30it/s]
0it [00:00, ?it/s]
[INFO] Epoch 26 loss: 0.016
32it [00:24, 1.32it/s]
0it [00:00, ?it/s]
[INFO] Epoch 27 loss: 0.016
32it [00:24, 1.32it/s]
0it [00:00, ?it/s]
[INFO] Epoch 28 loss: 0.016
32it [00:25, 1.28it/s]
0it [00:00, ?it/s]
[INFO] Epoch 29 loss: 0.016
32it [00:24, 1.29it/s]
[INFO] Epoch 30 loss: 0.016
[INFO] Finished Training
Wuhu~
3.5 结果可视化
超分辨率有多种分析方法,在此不一一展开介绍。
首先对模型的训练过程分析,
import matplotlib.pyplot as plt
plt.plot(history,label='Loss')
plt.legend(loc='best')
下面是一个跑了214次的模型,结果为
[这里有个BUG,损失值我多加1了]
能够看出模型训练是有效果的。
其次,构造一个可视化的输出函数,该函数读取一个图片地址,并且按照给定地址进行保存
from PIL import Image
def imshow(path,outputPath):
"""展示结果"""
preTransform = transforms.Compose([transforms.ToTensor()])
img = Image.open(path)
img = preTransform(img).unsqueeze(0)
#使用cpu就行
net.cpu()
source = net(img)[0,:,:,:]
source = source.cpu().detach().numpy() #转为numpy
source = source.transpose((1,2,0)) #切换形状
source = np.clip(source,0,1) #修正图片
img = Image.fromarray(np.uint8(source*255))
img.save(outputPath) # 将数组保存为图片
在这里使用了经典的Set5图片,其结果为
[左下角的为原图,分辨率为256x256]