基于端侧AI的超分辨率技术

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档


前言

随着人工智能的不断发展,机器学习这门技术也越来越重要,图像超分辨率(Super-Resolution, SR)同样也借助AI技术性能起飞。本文介绍一种在端侧部署的AI的超分辨率技术


一、图像超分是什么?

图像超分(Super-Resolution, SR)是一种图像处理技术,旨在从低分辨率(LR)图像中重建出高分辨率(HR)图像。这种技术试图恢复或增加图像的细节和清晰度,使其在放大或高分辨率显示设备上观看时更加清晰。

图像超分的挑战:

  • 信息丢失:在图像从高分辨率降到低分辨率的过程中,通常会丢失一些高频信息,如细节和边缘信息。
  • 放大伪影:简单地通过插值方法放大低分辨率图像,往往会产生不自然的图案和伪影,如锯齿状边缘。

图像超分的方法:

  1. 插值方法

    • 最简单的超分方法,如双线性插值、双三次插值等,通过估计像素值来放大图像。这些方法计算简单,但难以恢复丢失的高频细节。
  2. 基于重建的方法

    • 利用某些先验知识或模型来重建图像的细节。例如,稀疏编码(Sparse Coding)方法通过学习图像的稀疏表示来进行超分。
  3. 基于学习的方法

    • 近年来,基于深度学习的方法在图像超分领域取得了显著的进展。这些方法通常使用卷积神经网络(CNN)来学习从低分辨率到高分辨率图像的映射关系。
    • 生成对抗网络(GAN):通过对抗训练的方式,生成器网络学习生成高分辨率图像,而判别器网络则学习区分真实和生成的高分辨率图像。
    • 自编码器(Autoencoders):使用编码器将图像压缩成低维表示,然后通过解码器重建高分辨率图像。
    • 注意力机制:通过关注图像中的关键部分,提高超分图像的质量。
  4. 混合方法

    • 结合传统的图像处理技术和深度学习方法,以利用各自的优势。

应用场景:

  • 数字摄影:提高手机和相机拍摄的照片的分辨率。
  • 视频监控:增强监控视频的清晰度,以便更好地识别细节。
  • 医学成像:提高MRI或CT扫描的分辨率,以便于更准确的诊断。
  • 卫星图像:提高卫星图像的分辨率,以支持地理信息系统(GIS)和其他分析工具。

图像超分技术在提高图像质量、增强视觉体验以及支持专业分析方面发挥着重要作用。随着技术的发展,图像超分方法正变得越来越精确和高效。

二、使用步骤

这次介绍由我母校南京理工大学的团队设计的MobileSR模型
论文链接:https://openaccess.thecvf.com/content/CVPR2022W/NTIRE/papers/Li_NTIRE_2022_Challenge_on_Efficient_Super-Resolution_Methods_and_Results_CVPRW_2022_paper.pdf
MobileSR模型对输入图像进行超分辨率处理的确切步骤如下:

  1. 初始特征提取

    • 输入的低分辨率图像首先通过head模块,这是一个3x3的卷积层,用于从输入图像中提取初始特征。
  2. 特征处理

    • 初始特征通过BaseBlock模块进行处理。BaseBlock由多个TransformerResBlock层组成,每个Transformer层包含自注意力机制和多层感知机(MLP),而ResBlock层则用于进一步的特征提炼。
  3. 特征融合

    • 经过BaseBlock处理的特征与初始特征通过fuse模块进行融合。fuse是一个3x3的卷积层,用于合并特征。
  4. 上采样

    • 融合后的特征通过upsapling模块进行上采样。根据上采样因子的不同,upsapling模块包含一系列卷积层和像素shuffle层,用于将特征图尺寸放大到目标高分辨率图像的尺寸。
  5. 最终重建

    • 上采样后的特征图通过tail模块,这是一个3x3的卷积层,用于重建最终的高分辨率图像。
  6. 激活函数

    • 通过act模块,即LeakyReLU激活函数,增加模型的非线性能力。
  7. 输出

    • 最终,模型输出重建后的高分辨率图像。输出图像与通过双线性插值放大的原始输入图像相加,以改善细节和纹理。
  8. 测试

    • 在代码的测试部分,创建了一个随机初始化的张量a,模拟输入的低分辨率图像,并通过MobileSR模型进行处理,输出了超分辨率后的图像尺寸。

这个流程是MobileSR模型处理图像超分辨率的确切步骤,每个步骤都是通过精心设计的网络层和操作来实现的。

三、附上代码

import torch
from torch import nn 
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from einops import rearrange

###########################################################################
# Self-Attention 
class SelfAttn(nn.Module):
    """
    自注意力机制模块,用于实现Transformer中的注意力计算。
    """
    def __init__(self, dim, num_heads=8, bias=False):
        super(SelfAttn, self).__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        # QKV线性层,将输入特征映射到查询(Q)、键(K)和值(V)
        self.qkv = nn.Linear(dim, dim*3, bias=bias)
        # 输出投影层,将注意力输出映射回原始特征空间
        self.proj_out = nn.Linear(dim, dim)

    def forward(self, x):
        """
        前向传播函数,实现自注意力机制。
        
        参数:
        x (Tensor): 输入特征,形状为 (b, N, c),其中 b 是批次大小,N 是序列长度,c 是特征维度。
        
        输出:
        x (Tensor): 注意力输出特征,形状为 (b, N, c)。
        """
        b, N, c = x.shape
        
        # 将输入特征映射到查询(Q)、键(K)和值(V)
        qkv = self.qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.num_heads), qkv)
        
        # 计算注意力分数并应用softmax
        attn = torch.einsum('bijc, bikc -> bijk', q, k) * self.scale
        attn = attn.softmax(dim=-1)
        # 计算注意力加权的值
        x = torch.einsum('bijk, bikc -> bijc', attn, v)
        x = rearrange(x, 'b i j c -> b j (i c)')
        x = self.proj_out(x)
        return x
    
class Mlp(nn.Module):
    """
    多层感知机模块,用于Transformer中的前馈网络。
    """
    def __init__(self, in_features, mlp_ratio=4):
        super(Mlp, self).__init__()
        hidden_features = in_features * mlp_ratio
        
        self.fc = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.GELU(),
            nn.Linear(hidden_features, in_features)
        )

    def forward(self, x):
        """
        前向传播函数,实现多层感知机。
        
        参数:
        x (Tensor): 输入特征,形状为 (b, N, c),其中 b 是批次大小,N 是序列长度,c 是特征维度。
        
        输出:
        x (Tensor): MLP输出特征,形状为 (b, N, c)。
        """
        return self.fc(x)

    
def window_partition(x, window_size):
    """
    将输入的特征图分割成多个窗口,用于局部自注意力计算。
    
    参数:
    x (Tensor): 输入特征图,形状为 (b, h, w, c),其中 b 是批次大小,h 和 w 是特征图的高度和宽度,c 是特征维度。
    window_size (int): 窗口大小。
    
    输出:
    windows (Tensor): 分割后的窗口,形状为 (num_windows*b, window_size, window_size, c)。
    """
    return rearrange(x, 'b (h s1) (w s2) c -> (b h w) s1 s2 c', s1=window_size, s2=window_size)


def window_reverse(windows, window_size, h, w):
    """
    将分割的窗口重新组合成完整的特征图。
    
    参数:
    windows (Tensor): 分割后的窗口,形状为 (num_windows*b, window_size, window_size, c)。
    window_size (int): 窗口大小。
    h (int): 特征图的高度。
    w (int): 特征图的宽度。
    
    输出:
    x (Tensor): 重新组合的特征图,形状为 (b, h, w, c)。
    """
    b = int(windows.shape[0] / (h * w / window_size / window_size))
    return rearrange(windows, '(b h w) s1 s2 c -> b (h s1) (w s2) c', b=b, h=h//window_size, w=w//window_size)
    
    
class Transformer(nn.Module):
    """
    Transformer模块,包含自注意力机制和多层感知机。
    """
    def __init__(self, dim, num_heads=4, window_size=8, mlp_ratio=4, qkv_bias=False):
        super(Transformer, self).__init__()
        self.window_size=window_size
        self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
        
        self.norm1 = nn.LayerNorm(dim)
        self.attn = SelfAttn(dim, num_heads, qkv_bias)
        self.norm2 = nn.LayerNorm(dim)
      
        self.mlp = Mlp(dim, mlp_ratio)

    def forward(self, x):
        """
        前向传播函数,实现Transformer模块的计算。
        
        参数:
        x (Tensor): 输入特征图,形状为 (b, c, h, w),其中 b 是批次大小,c 是特征维度,h 和 w 是特征图的高度和宽度。
        
        输出:
        x (Tensor): Transformer输出特征图,形状为 (b, c, h, w)。
        """
        x = x + self.pos_embed(x)
        x = rearrange(x, 'b c h w -> b h w c')
        b, h, w, c = x.shape
        
        shortcut = x
        x = self.norm1(x)

        pad_l = pad_t = 0
        pad_r = (self.window_size - w % self.window_size) % self.window_size
        pad_b = (self.window_size - h % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape
        
        x_windows = window_partition(x, self.window_size)  
        x_windows = rearrange(x_windows, 'B s1 s2 c -> B (s1 s2) c', s1=self.window_size, s2=self.window_size)

        attn_windows = self.attn(x_windows)  
        attn_windows = rearrange(attn_windows, 'B (s1 s2) c -> B s1 s2 c', s1=self.window_size, s2=self.window_size)
        x = window_reverse(attn_windows, self.window_size, Hp, Wp)  

        if pad_r > 0 or pad_b > 0:
            x = x[:, :h, :w, :].contiguous()

        x = x + shortcut
        x = x + self.mlp(self.norm2(x))
        return rearrange(x, 'b h w c -> b c h w')


class ResBlock(nn.Module):
    """
    残差块模块,用于特征提炼。
    """
    def __init__(self, in_features, ratio=4):
        super(ResBlock, self).__init__()
    
        self.net = nn.Sequential(
            nn.Conv2d(in_features, in_features*ratio, 1, 1, 0),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(in_features*ratio, in_features*ratio, 3, 1, 1, groups=in_features*ratio),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(in_features*ratio, in_features, 1, 1, 0),
        )

    def forward(self, x):
        """
        前向传播函数,实现残差块的计算。
        
        参数:
        x (Tensor): 输入特征图,形状为 (b, c, h, w),其中 b 是批次大小,c 是特征维度,h 和 w 是特征图的高度和宽度。
        
        输出:
        x (Tensor): 残差块输出特征图,形状为 (b, c, h, w)。
        """
        return self.net(x) + x
    
    
class BaseBlock(nn.Module):
    """
    基础块模块,包含多个Transformer和ResBlock模块。
    """
    def __init__(self, dim, num_heads=8, window_size=8, ratios=[1, 2, 2, 4, 4], qkv_bias=False):
        super(BaseBlock, self).__init__()
        self.layers = nn.ModuleList([])
        for ratio in ratios:
            self.layers.append(nn.ModuleList([ 
                Transformer(dim, num_heads, window_size, ratio, qkv_bias),
                ResBlock(dim, ratio)
            ]))
            
    def forward(self, x):
        """
        前向传播函数,实现基础块的计算。
        
        参数:
        x (Tensor): 输入特征图,形状为 (b, c, h, w),其中 b 是批次大小,c 是特征维度,h 和 w 是特征图的高度和宽度。
        
        输出:
        x (Tensor): 基础块输出特征图,形状为 (b, c, h, w)。
        """
        for tblock, rblock in self.layers:
            x = tblock(x)
            x = rblock(x)
        return x 
   
    
class MobileSR(nn.Module):
    """
    MobileSR模型,用于图像超分辨率。
    """
    def __init__(self, n_feats=40, n_heads=8, ratios=[4, 2, 2, 4], upscaling_factor=4):
        super(MobileSR, self).__init__()
        self.scale = upscaling_factor 
        # 初始卷积层,将输入图像映射到特征图
        self.head = nn.Conv2d(3, n_feats, 3, 1, 1)
        
        # 基础块模块,包含多个Transformer和ResBlock模块
        self.body = BaseBlock(n_feats, num_heads=n_feats, ratios=ratios)
    
        # 融合层,将body的输出与初始特征图融合
        self.fuse = nn.Conv2d(n_feats*2, n_feats, 3, 1, 1)
        
        # 上采样模块,根据上采样因子构建不同的上采样网络
        if self.scale == 4:
            self.upsapling = nn.Sequential(
                nn.Conv2d(n_feats, n_feats*4, 1, 1, 0),
                nn.PixelShuffle(2),
                nn.Conv2d(n_feats, n_feats*4, 1, 1, 0),
                nn.PixelShuffle(2)
            )
        else:
            self.upsapling = nn.Sequential(
                nn.Conv2d(n_feats, n_feats*self.scale*self.scale, 1, 1, 0),
                nn.PixelShuffle(self.scale)
            )
        
        # 最终卷积层,将上采样后的特征图映射到输出图像
        self.tail = nn.Conv2d(n_feats, 3, 3, 1, 1)
        # 激活函数
        self.act = nn.LeakyReLU(0.2, inplace=True) 
        
    def forward(self, x):
        """
        前向传播函数,实现图像超分辨率的计算。
        
        参数:
        x (Tensor): 输入低分辨率图像,形状为 (b, 3, h, w),其中 b 是批次大小,3 是颜色通道,h 和 w 是图像的高度和宽度。
        
        输出:
        x (Tensor): 超分辨率后的高分辨率图像,形状为 (b, 3, h*scale, w*scale)。
        """
        # 通过初始卷积层提取特征
        x0 = self.head(x)
        # 通过基础块模块处理特征
        x0 = self.body(x0)
        # 将处理后的特征与初始特征融合
        x0 = self.fuse(torch.cat([x0, self.body(x0)], dim=1))
        # 上采样处理
        x0 = self.upsapling(x0)
        # 通过最终卷积层重建图像
        x0 = self.tail(self.act(x0))
        # 双线性插值上采样
        x = F.interpolate(x, scale_factor=self.scale, mode='bilinear', align_corners=False)
        # 返回超分辨率后的图像与原始图像的和
        return x0 + x  
    
if __name__== '__main__':
    """
    测试MobileSR模型。
    """
    # 创建一个随机初始化的张量,模拟输入的低分辨率图像
    a = torch.randn(1, 3, 23, 37)
    # 创建MobileSR模型实例
    model = MobileSR()
    # 通过模型进行前向传播,得到超分辨率后的图像
    output = model(a)
    # 打印输出图像的形状
    print(output.shape)

这个注释详细解释了每个类和函数的作用,以及它们在图像超分辨率任务中的具体用途。每个函数的输入和输出都进行了详细的说明,以便更好地理解MobileSR模型的工作原理和处理流程。
附上源代码链接:https://github.com/sunny2109/MobileSR-NTIRE2022

四、实测:

在这里插入图片描述
原图(1788 x 804)

SR后(×4)


总结

以上就是今天要讲的内容,本文简单介绍了基于端侧AI超分辨率技术,以及mobileSR超分模型的原理和使用

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值