详尽基础:基于PyTorch的超分重建

点击上方“机器学习与生成对抗网络”,关注星标

获取有趣、好玩的前沿干货!

好消息,本文末免费送书!

  在深度学习中,像Unet这种类似encoder+decoder结构并且输入和输出均为图片的网络有非常广的应用范围。


超分辨率重建算法的发展

   分辨率重建指的是将一副低分辨率的图片进行处理,恢复出高分辨率图片的一种图像处理技术。这种技术可以改善图像的视觉效果,也能帮助对图像进行进一步的识别和处理。目前,基于深度学习的超分辨率重建算法已经成为该领域的研究热点。下面介绍一下几种经典的深度学习超分辨率重建算法。

1.    SRCNN

   SRCNN[1]是最早的超分辨率重建算法,先使用双线性插值将图片缩放到期望的大小,然后使用非线性网络进行特征提取和重建,只用到了两个卷积层,其结构如图1所示。


[1] Dong C, Loy C C, He K, et al. Image super-resolution using deepconvolutional networks[J]. IEEE transactions on pattern analysis and machineintelligence, 2015, 38(2): 295-307.

图1    SRCNN 结构示意图

   从这个网络中可以看到,超分辨率重建问题对网络结构的要求并不高,这种简单到极致的网络都可以轻松完成任务。

2.    FSRCNN

   对SRCNN的改进,FSRCNN[1]中的创新点如下。

  • 采用反卷积来放大图片,这样在进行不同比例的超分辨率重建时,只需训练反卷积部分的参数即可,其余层的参数可以保持不变。


[1] Dong C, Loy C C, Tang X. Accelerating the super-resolutionconvolutional neural network[C]//European conference on computer vision.Springer, Cham, 2016: 391-407.

  • 使用1x1卷积来进行降维,减少了模型计算量。

  • 使用更小的卷积核和更多的卷积层。

  图2是FSRCNN的结构示意图,从图中可以看出,模糊图片经过多层卷积之后,得到一个特征图,再使用反卷积和1x1卷积将特征图放大和降维,就可以得到最终的高清图片,只需训练反卷积部分就可以实现多种不同比例的超分辨率重建模型了。[a1] 

图2    FSRCNN示意图

3.    VDSR

在分割网络中使用了残差网络,也就是将训练目标从高清图片转化成了高清图片与低清图片之间的像素差值。这个算法的创新点如下。

  • 使用了残差结构,并在训练中添加了梯度剪裁操作,防止梯度爆炸。

  • 将网络加深到20层,使模型具备了更大的感受野。

  • 将不同缩放比例的图片混合在一起训练,这样模型能够解决不同倍数的高分辨率重建。

  VDSR[1]的网络结构图如图3所示,VDSR的卷积网络变得更深,图片经过多层卷积之后得到的计算结果会与原图相加,得到最终的高清图片,在这种结构下,模型拟合的是高清图片和模糊图片之间的残差,比直接拟合高清图片更加容易。


[1] Kim J, Kwon Lee J, Mu Lee K. Accurate image super-resolution usingvery deep convolutional networks[C]//Proceedings of the IEEE conference oncomputer vision and pattern recognition. 2016: 1646-1654.

图3  VDSR示意图

数据加载

   这个任务的数据生成很简单,把搜集来的任意图片集作为标签,然后借助OpenCV或者PIL等工具将这些图片进行模糊化,即可得到训练数据。

   为了让模型拟合更快,可以选择特定的某一类图片来训练。比如在超分辨率重建的开山之作SRCNN中展示模型效果时,使用的是蝴蝶图片,那么这里也可以选择蝴蝶图片来进行训练,其图片下载方式与第2章中的物体检测相同,搜索“蝴蝶特写”之类的关键词,可以很容易搜到如图4所示的图片。

   本项目中共使用了1381张蝴蝶图片,其中大部分图片只包含了一只蝴蝶的特写,且背景相对简单。

图5‑14蝴蝶图片样例

1.   数据加载

在数据预处理及加载的过程中,我们对图片进行了通道格式转换和通道抽取,并进行了在线模糊化处理,而模糊处理操作选择了PIL库中的ImageFilter.BLUR函数,其代码如下:

# super_resolution_data.py
from torch.utils.data import Dataset
from torchvision import transforms


from glob import glob
import os.path as osp
from PIL import Image, ImageFilter
from sklearn.model_selection import train_test_split


from config import sr_data_folder


class SuperResolutionData(Dataset):
    def __init__(
        self,
        data_folder=sr_data_folder,
        subset="train",
        transform=None,
        demo=False,
    ):
        """
        data_folder: 数据文件夹
        subset: 训练集或者测试集
        transform: 数据增强方法
        demo:demo模式(数据增强方法不同)
        """
        self.img_paths = sorted(glob(osp.join(sr_data_folder, "*.jpg")))
        train_paths, test_paths = train_test_split(
            self.img_paths, test_size=0.2, random_state=10
        )
        # 训练集
        if subset == "train":
            self.img_paths = train_paths
        # 测试集
        else:
            self.img_paths = test_paths
        self.subset = subset
        # demo模式
        self.demo = demo
        # 如果没有定义tranform,则使用默认transform
        if transform is None:
            self.transform = transforms.ToTensor()
        else:
            self.transform = transform


    def __getitem__(self, index):
        # 将高清图片转换成YCbCr
        high = (
            Image.open(self.img_paths[index])
            .resize((256, 256))
            .convert("YCbCr")
        )
        # 划分通道
        high_y, high_cb, high_cr = high.split()
        # 模糊化
        low = high.filter(ImageFilter.BLUR())
        # 划分通道
        low_y, low_cb, low_cr = low.split()
        # 训练集
        if self.subset == "train":
            # demo模式下,返回各个通道
            if self.demo:
                return (
                    self.transform(low_y),
                    self.transform(high_y),
                    (high_cb, high_cr, low_cb, low_cr),
                )
            else:
                return self.transform(low_y), self.transform(high_y)
        # 测试集
        else:
            totensor = transforms.ToTensor()
            if self.demo:
                return (
                    totensor(low_y),
                    totensor(high_y),
                    (high_cb, high_cr, low_cb, low_cr),
                )
            else:
                return totensor(low_y), totensor(high_y)


    def __len__(self):
        return len(self.img_paths)


在上述代码中,实现了超分辨重建数据集,在__init__()方法中,我们加载了所有图片的路径并划分了训练集和验证集;在__getitem__()方法中,我们对图像从RGB格式转换成了YcbCr格式,并进行了通道分割,然后设置了演示模式。在演示模式下,会返回模糊和高清图片的所有通道数据;在非演示模式下,只返回模糊和高清图片的Y通道数据。

2.   图片对比

通过如下代码,可以查看原始图片和模糊化之后的图片:

# tools/show_sample_data.py
# 在tools目录下运行
import torch
from torch import nn
from torchvision.transforms import ToPILImage


import matplotlib.pyplot as plt
from PIL import Image
import sys
# 将上级目录加入系统目录
sys.path.append("..")
from super_resolution_data import SuperResolutionData
# 从测试集中找图片进行演示
test_data = SuperResolutionData(subset="test", demo=True)
low, high, (high_cb, high_cr, low_cb, low_cr) = test_data[0]
topil = ToPILImage()
plt.subplot(121)
plt.title("low")
# 合并通道才能得到一张完整图片
low_rgb = Image.merge("YCbCr", [topil(low), low_cb, low_cr]).convert("RGB")
plt.imshow(low_rgb)
plt.subplot(122)
plt.title("high")
# 合并通道才能得到一张完整图片
high_rgb = Image.merge("YCbCr", [topil(high), high_cb, high_cr]).convert("RGB")
plt.imshow(high_rgb)
plt.savefig("../img/sr_sample.jpg")
plt.show()


上述代码加载了测试集,并从训练集中分别获取到模糊图片和高清图片的三个通道之后,将三个通道合并得到完整的模糊图片和高清图片,最后将两张图片绘制出来。模糊图片与高清图片如图55所示对代码的解释建议再详细一些。

图5  经模糊处理的蝴蝶图片

模型搭建与训练

 

   可以直接使用在图像分割任务中搭建的ResNet18Unet来完成这个任务,但是模型最后的输出类别要改成1.因为本节要以回归的思路搭建这个超分辨率重建模型,直接生成高清图片中的Y通道,然后再与原图中的CbCr通道合并,得到最终的高清图片,根据MSELoss这一回归损失函数来优化模型。

   在本节的超分辨率重建模型的训练过程中,我们使用了两个技巧。

  • 将图片转化成YCbCr通道格式,只训练亮度通道Y。

  • 不直接训练图片,而是训练高清图片和模糊图片之间的残差,这样能减小这个回归问题的训练难度。

    下面是超分辨率重建模型的训练代码:

  • # super_resolution_train.py
    import torch
    from torch import nn, optim
    from torch.utils.data import DataLoader
    
    
    from tqdm import tqdm
    import os.path as osp
    
    
    from super_resolution_data import SuperResolutionData, transform
    from model import ResNet18Unet
    from config import device, sr_checkpoint, batch_size, epoch_lr
    from torch.utils.tensorboard import SummaryWriter
    from transform import TrainTransform, TestTransform
    
    
    def train():
        # 建立模型
        net = ResNet18Unet(num_classes=1)
        # 只训练Y通道
        net.firstconv = nn.Conv2d(
            1, 64, kernel_size=7, stride=2, padding=3, bias=False
        )
        # 将模型转入GPU
        net = net.to(device)
        # 加载数据集
        trainset = SuperResolutionData(subset="train", transform=TrainTransform)
        testset = SuperResolutionData(subset="test", transform=TestTransform)
        # 加载dataloader
        trainloader = DataLoader(
            trainset, batch_size=batch_size, shuffle=True, num_workers=4
        )
        testloader = DataLoader(
            testset, batch_size=batch_size, shuffle=True, num_workers=4
        )
        # 损失函数
        criteron = nn.MSELoss()
        # 最佳损失,用于筛选最佳模型
        best_loss = 1e9
    
    
        if osp.exists(sr_checkpoint):
            ckpt = torch.load(sr_checkpoint)
            best_loss = ckpt["loss"]
            net.load_state_dict(ckpt["params"])
            print("checkpoint loaded ...")
    
    
        writer = SummaryWriter("super_log")
        for n, (num_epochs, lr) in enumerate(epoch_lr):
            optimizer = optim.SGD(
                net.parameters(), lr=lr, momentum=0.9, weight_decay=5e-3
            )
            for epoch in range(num_epochs):
                net.train()
                pbar = tqdm(enumerate(trainloader), total=len(trainloader))
                epoch_loss = 0.0
                for i, (img, mask) in pbar:
                    img = img.to(device)
                    mask = mask.to(device)
                    out = net(img)
                    # 只训练样本与标签之间的残差
                    loss = criteron(out + img, mask)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    if i % 10 == 0:
                        pbar.set_description("loss: {}".format(loss))
                    epoch_loss += loss.item()
                print("Epoch_loss:{}".format(epoch_loss / len(trainloader.dataset)))
                writer.add_scalar(
                    "super_epoch_loss",
                    epoch_loss / len(trainloader.dataset),
                    sum([e[0] for e in epoch_lr[:n]]) + epoch,
                )
                # 无梯度模式下快速验证
                with torch.no_grad():
                    # 验证模式
                    net.eval()
                    test_loss = 0.0
                    for i, (img, mask) in tqdm(
                        enumerate(testloader), total=len(testloader)
                    ):
                        img = img.to(device)
                        mask = mask.to(device)
                        out = net(img)
                        loss = criteron(out + img, mask)
                        # 累计loss
                        test_loss += loss.item()
                    print(
                        "Test_loss:{}".format(test_loss / len(testloader.dataset))
                    )
                    # 将loss加入tensorboard
                    writer.add_scalar(
                        "super_test_loss",
                        test_loss / len(testloader.dataset),
                        sum([e[0] for e in epoch_lr[:n]]) + epoch,
                    )
                # 如果模型效果比当前最好的模型都好,则保存模型参数
                if test_loss < best_loss:
                    best_loss = test_loss
                    torch.save(
                        {"params": net.state_dict(), "loss": test_loss},
                        sr_checkpoint,
                    )
        writer.close()
    
    
    if __name__ == "__main__":
        train()
    
    
    

上述代码中实现了超分辨率重建模型的训练过程,先使用训练集训练模型,然后在验证集上测试模型效果,如果在验证模型时发现模型的损失值得到了改善,则将改善后的模型保存下来,这样能够避免过拟合之后的模型覆盖掉最优模型。在计算损失时,将模型的预测值out与模型输入值img相加后再与mask计算损失,这种方式能获得更好的效果。

训练过程中的loss变化如图6和图7所示,从图中可以看出,模型在训练集和验证集上的损失较为接近,且在20个epoch之后曲线变得平缓,可以认为模型已经训练到了较理想的状态。

图6      超分辨率重建训练集loss曲线

图7    超分辨率重建测试集loss曲线

模型展示

训练完成之后,可以把生成的图片与样本中的两张图片做一个对比:

# super_solution_demo.py
import torch
from torch import nn
from torchvision.transforms import ToPILImage
import matplotlib.pyplot as plt
from PIL import Image


from model import ResNet18Unet
from super_resolution_data import SuperResolutionData
from config import sr_checkpoint, device


net = ResNet18Unet(num_classes=1)
# 只处理Y通道
net.firstconv = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
net = net.to(device)
net.load_state_dict(torch.load(sr_checkpoint)["params"])
# 从测试集中找图片验证
test_data = SuperResolutionData(subset="test", demo=True)
low, high, (high_cb, high_cr, low_cb, low_cr) = test_data[0]
mask = net(low.unsqueeze(0).to(device)).squeeze(0).data.cpu()
topil = ToPILImage()
plt.subplot(131)
plt.title("low")
# 合并通道
low_rgb = Image.merge("YCbCr", [topil(low), low_cb, low_cr]).convert("RGB")
plt.imshow(low_rgb)
plt.subplot(132)
plt.title("rebuilt")
# 残差累加,还原预测结果
rebuilt = mask + low
# 通道合并
rebuilt_rgb = Image.merge("YCbCr", [topil(rebuilt), low_cb, low_cr]).convert(
    "RGB"
)
plt.imshow(rebuilt_rgb)
plt.subplot(133)
plt.title("high")
high_rgb = Image.merge("YCbCr", [topil(high), high_cb, high_cr]).convert("RGB")
plt.imshow(high_rgb)
plt.savefig("img/sr_result.jpg")
plt.show()


上述代码中实现了超分辨率重建模型的预测过程,分为三个步骤:

1. 首先建立了一个ResNet18Unet模型,然后将模型的输入通道(修改第一个卷积层的输入通道数量)和输出通道(修改最终的输出类别数)都修改成1,然后加载预训练模型参数;

2. 拆分原图的通道,并将Y通道输入到模型中进行前向推理,得到预测结果;

3. 将预测结果与原图中的CbCr两个通道进行合并,得到预测图片;

4. 绘制模糊图片、预测图片和高清图片的对比图。

得到的效果如图8所示,从中可以看到,图片的清晰度有了很大的提升。这说明我们的超分辨率重建模型已经学习到了模糊图片和清晰图片之间的像素映射关系。

图 8    重建前后图片对比

  本文选自----人民邮电出版社出版的《Python计算机视觉与深度学习实战》一书中,经授权此公号。

文末赠书

内容简介

《Python计算机视觉与深度学习实战》立足实践,从机器学习的基础技能出发,深入浅出地介绍了如何使用 Python 进行基于深度学习的计算机视觉项目开发。开篇介绍了基于传统机器学习及图像处理方法的计算机视觉技术;然后重点就图像分类、目标检测、图像分割、图像搜索、图像压缩及文本识别等常见的计算机视觉项目做了理论结合实践的讲解;后探索了深度学习项目落地时会用到的量化、剪枝等技术,并提供了模型服务端部署案例。

~

【活动】

本次为大家免费寄送纸质正版图书!9月17日22点结束并开奖。

参与方法:

1、文末点 在看 

2、公众号后台、或者长按扫下码,回复 168 ,参与抽奖!

  • 2
    点赞
  • 27
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
ECharts(Enterprise Charts)是百度开源的一个基于JavaScript的可视化图表库,它提供了丰富的图表类型和交互功能,可以帮助开发者快速构建各种数据可视化的应用。蝴蝶图(Butterfly Chart)是ECharts中的一种特殊类型的图表,它主要用于展示两个相对独立的数据集之间的对比关系。 蝴蝶图通常由两个相互镜像的柱状图组成,其中一个柱状图表示正值数据,另一个柱状图表示负值数据。通过这种方式,可以直观地比较两个数据集之间的差异,并且可以清晰地展示正负值之间的关系。 在ECharts中,使用蝴蝶图可以通过以下步骤实现: 1. 引入ECharts库和相关依赖文件。 2. 创建一个DOM容器,用于显示图表。 3. 初始化ECharts实例,并设置容器和主题。 4. 配置蝴蝶图的数据和样式,包括正负值数据、颜色、标签等。 5. 将配置项应用到ECharts实例中,并渲染出图表。 以下是一个简单的示例代码,展示了如何使用ECharts创建一个蝴蝶图: ```javascript // 引入ECharts库 import echarts from 'echarts'; // 创建DOM容器 const container = document.getElementById('chart-container'); // 初始化ECharts实例 const chart = echarts.init(container); // 配置蝴蝶图的数据和样式 const option = { tooltip: { trigger: 'axis', axisPointer: { type: 'shadow' } }, legend: { data: ['正值', '负值'] }, xAxis: { type: 'category', data: ['数据1', '数据2', '数据3', '数据4', '数据5'] }, yAxis: { type: 'value' }, series: [ { name: '正值', type: 'bar', stack: '总量', label: { show: true, position: 'inside' }, data: [100, 200, 300, 400, 500] }, { name: '负值', type: 'bar', stack: '总量', label: { show: true, position: 'inside' }, data: [-100, -200, -300, -400, -500] } ] }; // 将配置项应用到ECharts实例中 chart.setOption(option); ``` 这是一个简单的蝴蝶图示例,你可以根据自己的需求进行进一步的配置和样式调整。希望对你有所帮助!

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值