提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档
前言
随着人工智能的不断发展,机器学习这门技术也越来越重要,图像超分辨率(Super-Resolution, SR)同样也借助AI技术性能起飞。本文介绍一种在端侧部署的AI的超分辨率技术
一、图像超分是什么?
图像超分(Super-Resolution, SR)是一种图像处理技术,旨在从低分辨率(LR)图像中重建出高分辨率(HR)图像。这种技术试图恢复或增加图像的细节和清晰度,使其在放大或高分辨率显示设备上观看时更加清晰。
图像超分的挑战:
- 信息丢失:在图像从高分辨率降到低分辨率的过程中,通常会丢失一些高频信息,如细节和边缘信息。
- 放大伪影:简单地通过插值方法放大低分辨率图像,往往会产生不自然的图案和伪影,如锯齿状边缘。
图像超分的方法:
-
插值方法:
- 最简单的超分方法,如双线性插值、双三次插值等,通过估计像素值来放大图像。这些方法计算简单,但难以恢复丢失的高频细节。
-
基于重建的方法:
- 利用某些先验知识或模型来重建图像的细节。例如,稀疏编码(Sparse Coding)方法通过学习图像的稀疏表示来进行超分。
-
基于学习的方法:
- 近年来,基于深度学习的方法在图像超分领域取得了显著的进展。这些方法通常使用卷积神经网络(CNN)来学习从低分辨率到高分辨率图像的映射关系。
- 生成对抗网络(GAN):通过对抗训练的方式,生成器网络学习生成高分辨率图像,而判别器网络则学习区分真实和生成的高分辨率图像。
- 自编码器(Autoencoders):使用编码器将图像压缩成低维表示,然后通过解码器重建高分辨率图像。
- 注意力机制:通过关注图像中的关键部分,提高超分图像的质量。
-
混合方法:
- 结合传统的图像处理技术和深度学习方法,以利用各自的优势。
应用场景:
- 数字摄影:提高手机和相机拍摄的照片的分辨率。
- 视频监控:增强监控视频的清晰度,以便更好地识别细节。
- 医学成像:提高MRI或CT扫描的分辨率,以便于更准确的诊断。
- 卫星图像:提高卫星图像的分辨率,以支持地理信息系统(GIS)和其他分析工具。
图像超分技术在提高图像质量、增强视觉体验以及支持专业分析方面发挥着重要作用。随着技术的发展,图像超分方法正变得越来越精确和高效。
二、使用步骤
这次介绍由我母校南京理工大学的团队设计的MobileSR模型
论文链接:https://openaccess.thecvf.com/content/CVPR2022W/NTIRE/papers/Li_NTIRE_2022_Challenge_on_Efficient_Super-Resolution_Methods_and_Results_CVPRW_2022_paper.pdf
MobileSR模型对输入图像进行超分辨率处理的确切步骤如下:
-
初始特征提取:
- 输入的低分辨率图像首先通过
head
模块,这是一个3x3的卷积层,用于从输入图像中提取初始特征。
- 输入的低分辨率图像首先通过
-
特征处理:
- 初始特征通过
BaseBlock
模块进行处理。BaseBlock
由多个Transformer
和ResBlock
层组成,每个Transformer
层包含自注意力机制和多层感知机(MLP),而ResBlock
层则用于进一步的特征提炼。
- 初始特征通过
-
特征融合:
- 经过
BaseBlock
处理的特征与初始特征通过fuse
模块进行融合。fuse
是一个3x3的卷积层,用于合并特征。
- 经过
-
上采样:
- 融合后的特征通过
upsapling
模块进行上采样。根据上采样因子的不同,upsapling
模块包含一系列卷积层和像素shuffle层,用于将特征图尺寸放大到目标高分辨率图像的尺寸。
- 融合后的特征通过
-
最终重建:
- 上采样后的特征图通过
tail
模块,这是一个3x3的卷积层,用于重建最终的高分辨率图像。
- 上采样后的特征图通过
-
激活函数:
- 通过
act
模块,即LeakyReLU激活函数,增加模型的非线性能力。
- 通过
-
输出:
- 最终,模型输出重建后的高分辨率图像。输出图像与通过双线性插值放大的原始输入图像相加,以改善细节和纹理。
-
测试:
- 在代码的测试部分,创建了一个随机初始化的张量
a
,模拟输入的低分辨率图像,并通过MobileSR
模型进行处理,输出了超分辨率后的图像尺寸。
- 在代码的测试部分,创建了一个随机初始化的张量
这个流程是MobileSR模型处理图像超分辨率的确切步骤,每个步骤都是通过精心设计的网络层和操作来实现的。
三、附上代码
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from einops import rearrange
###########################################################################
# Self-Attention
class SelfAttn(nn.Module):
"""
自注意力机制模块,用于实现Transformer中的注意力计算。
"""
def __init__(self, dim, num_heads=8, bias=False):
super(SelfAttn, self).__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
# QKV线性层,将输入特征映射到查询(Q)、键(K)和值(V)
self.qkv = nn.Linear(dim, dim*3, bias=bias)
# 输出投影层,将注意力输出映射回原始特征空间
self.proj_out = nn.Linear(dim, dim)
def forward(self, x):
"""
前向传播函数,实现自注意力机制。
参数:
x (Tensor): 输入特征,形状为 (b, N, c),其中 b 是批次大小,N 是序列长度,c 是特征维度。
输出:
x (Tensor): 注意力输出特征,形状为 (b, N, c)。
"""
b, N, c = x.shape
# 将输入特征映射到查询(Q)、键(K)和值(V)
qkv = self.qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.num_heads), qkv)
# 计算注意力分数并应用softmax
attn = torch.einsum('bijc, bikc -> bijk', q, k) * self.scale
attn = attn.softmax(dim=-1)
# 计算注意力加权的值
x = torch.einsum('bijk, bikc -> bijc', attn, v)
x = rearrange(x, 'b i j c -> b j (i c)')
x = self.proj_out(x)
return x
class Mlp(nn.Module):
"""
多层感知机模块,用于Transformer中的前馈网络。
"""
def __init__(self, in_features, mlp_ratio=4):
super(Mlp, self).__init__()
hidden_features = in_features * mlp_ratio
self.fc = nn.Sequential(
nn.Linear(in_features, hidden_features),
nn.GELU(),
nn.Linear(hidden_features, in_features)
)
def forward(self, x):
"""
前向传播函数,实现多层感知机。
参数:
x (Tensor): 输入特征,形状为 (b, N, c),其中 b 是批次大小,N 是序列长度,c 是特征维度。
输出:
x (Tensor): MLP输出特征,形状为 (b, N, c)。
"""
return self.fc(x)
def window_partition(x, window_size):
"""
将输入的特征图分割成多个窗口,用于局部自注意力计算。
参数:
x (Tensor): 输入特征图,形状为 (b, h, w, c),其中 b 是批次大小,h 和 w 是特征图的高度和宽度,c 是特征维度。
window_size (int): 窗口大小。
输出:
windows (Tensor): 分割后的窗口,形状为 (num_windows*b, window_size, window_size, c)。
"""
return rearrange(x, 'b (h s1) (w s2) c -> (b h w) s1 s2 c', s1=window_size, s2=window_size)
def window_reverse(windows, window_size, h, w):
"""
将分割的窗口重新组合成完整的特征图。
参数:
windows (Tensor): 分割后的窗口,形状为 (num_windows*b, window_size, window_size, c)。
window_size (int): 窗口大小。
h (int): 特征图的高度。
w (int): 特征图的宽度。
输出:
x (Tensor): 重新组合的特征图,形状为 (b, h, w, c)。
"""
b = int(windows.shape[0] / (h * w / window_size / window_size))
return rearrange(windows, '(b h w) s1 s2 c -> b (h s1) (w s2) c', b=b, h=h//window_size, w=w//window_size)
class Transformer(nn.Module):
"""
Transformer模块,包含自注意力机制和多层感知机。
"""
def __init__(self, dim, num_heads=4, window_size=8, mlp_ratio=4, qkv_bias=False):
super(Transformer, self).__init__()
self.window_size=window_size
self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
self.norm1 = nn.LayerNorm(dim)
self.attn = SelfAttn(dim, num_heads, qkv_bias)
self.norm2 = nn.LayerNorm(dim)
self.mlp = Mlp(dim, mlp_ratio)
def forward(self, x):
"""
前向传播函数,实现Transformer模块的计算。
参数:
x (Tensor): 输入特征图,形状为 (b, c, h, w),其中 b 是批次大小,c 是特征维度,h 和 w 是特征图的高度和宽度。
输出:
x (Tensor): Transformer输出特征图,形状为 (b, c, h, w)。
"""
x = x + self.pos_embed(x)
x = rearrange(x, 'b c h w -> b h w c')
b, h, w, c = x.shape
shortcut = x
x = self.norm1(x)
pad_l = pad_t = 0
pad_r = (self.window_size - w % self.window_size) % self.window_size
pad_b = (self.window_size - h % self.window_size) % self.window_size
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
_, Hp, Wp, _ = x.shape
x_windows = window_partition(x, self.window_size)
x_windows = rearrange(x_windows, 'B s1 s2 c -> B (s1 s2) c', s1=self.window_size, s2=self.window_size)
attn_windows = self.attn(x_windows)
attn_windows = rearrange(attn_windows, 'B (s1 s2) c -> B s1 s2 c', s1=self.window_size, s2=self.window_size)
x = window_reverse(attn_windows, self.window_size, Hp, Wp)
if pad_r > 0 or pad_b > 0:
x = x[:, :h, :w, :].contiguous()
x = x + shortcut
x = x + self.mlp(self.norm2(x))
return rearrange(x, 'b h w c -> b c h w')
class ResBlock(nn.Module):
"""
残差块模块,用于特征提炼。
"""
def __init__(self, in_features, ratio=4):
super(ResBlock, self).__init__()
self.net = nn.Sequential(
nn.Conv2d(in_features, in_features*ratio, 1, 1, 0),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(in_features*ratio, in_features*ratio, 3, 1, 1, groups=in_features*ratio),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(in_features*ratio, in_features, 1, 1, 0),
)
def forward(self, x):
"""
前向传播函数,实现残差块的计算。
参数:
x (Tensor): 输入特征图,形状为 (b, c, h, w),其中 b 是批次大小,c 是特征维度,h 和 w 是特征图的高度和宽度。
输出:
x (Tensor): 残差块输出特征图,形状为 (b, c, h, w)。
"""
return self.net(x) + x
class BaseBlock(nn.Module):
"""
基础块模块,包含多个Transformer和ResBlock模块。
"""
def __init__(self, dim, num_heads=8, window_size=8, ratios=[1, 2, 2, 4, 4], qkv_bias=False):
super(BaseBlock, self).__init__()
self.layers = nn.ModuleList([])
for ratio in ratios:
self.layers.append(nn.ModuleList([
Transformer(dim, num_heads, window_size, ratio, qkv_bias),
ResBlock(dim, ratio)
]))
def forward(self, x):
"""
前向传播函数,实现基础块的计算。
参数:
x (Tensor): 输入特征图,形状为 (b, c, h, w),其中 b 是批次大小,c 是特征维度,h 和 w 是特征图的高度和宽度。
输出:
x (Tensor): 基础块输出特征图,形状为 (b, c, h, w)。
"""
for tblock, rblock in self.layers:
x = tblock(x)
x = rblock(x)
return x
class MobileSR(nn.Module):
"""
MobileSR模型,用于图像超分辨率。
"""
def __init__(self, n_feats=40, n_heads=8, ratios=[4, 2, 2, 4], upscaling_factor=4):
super(MobileSR, self).__init__()
self.scale = upscaling_factor
# 初始卷积层,将输入图像映射到特征图
self.head = nn.Conv2d(3, n_feats, 3, 1, 1)
# 基础块模块,包含多个Transformer和ResBlock模块
self.body = BaseBlock(n_feats, num_heads=n_feats, ratios=ratios)
# 融合层,将body的输出与初始特征图融合
self.fuse = nn.Conv2d(n_feats*2, n_feats, 3, 1, 1)
# 上采样模块,根据上采样因子构建不同的上采样网络
if self.scale == 4:
self.upsapling = nn.Sequential(
nn.Conv2d(n_feats, n_feats*4, 1, 1, 0),
nn.PixelShuffle(2),
nn.Conv2d(n_feats, n_feats*4, 1, 1, 0),
nn.PixelShuffle(2)
)
else:
self.upsapling = nn.Sequential(
nn.Conv2d(n_feats, n_feats*self.scale*self.scale, 1, 1, 0),
nn.PixelShuffle(self.scale)
)
# 最终卷积层,将上采样后的特征图映射到输出图像
self.tail = nn.Conv2d(n_feats, 3, 3, 1, 1)
# 激活函数
self.act = nn.LeakyReLU(0.2, inplace=True)
def forward(self, x):
"""
前向传播函数,实现图像超分辨率的计算。
参数:
x (Tensor): 输入低分辨率图像,形状为 (b, 3, h, w),其中 b 是批次大小,3 是颜色通道,h 和 w 是图像的高度和宽度。
输出:
x (Tensor): 超分辨率后的高分辨率图像,形状为 (b, 3, h*scale, w*scale)。
"""
# 通过初始卷积层提取特征
x0 = self.head(x)
# 通过基础块模块处理特征
x0 = self.body(x0)
# 将处理后的特征与初始特征融合
x0 = self.fuse(torch.cat([x0, self.body(x0)], dim=1))
# 上采样处理
x0 = self.upsapling(x0)
# 通过最终卷积层重建图像
x0 = self.tail(self.act(x0))
# 双线性插值上采样
x = F.interpolate(x, scale_factor=self.scale, mode='bilinear', align_corners=False)
# 返回超分辨率后的图像与原始图像的和
return x0 + x
if __name__== '__main__':
"""
测试MobileSR模型。
"""
# 创建一个随机初始化的张量,模拟输入的低分辨率图像
a = torch.randn(1, 3, 23, 37)
# 创建MobileSR模型实例
model = MobileSR()
# 通过模型进行前向传播,得到超分辨率后的图像
output = model(a)
# 打印输出图像的形状
print(output.shape)
这个注释详细解释了每个类和函数的作用,以及它们在图像超分辨率任务中的具体用途。每个函数的输入和输出都进行了详细的说明,以便更好地理解MobileSR模型的工作原理和处理流程。
附上源代码链接:https://github.com/sunny2109/MobileSR-NTIRE2022
四、实测:
原图(1788 x 804)
SR后(×4)
总结
以上就是今天要讲的内容,本文简单介绍了基于端侧AI超分辨率技术,以及mobileSR超分模型的原理和使用