文章目录
一、论文详解
- 中文:使用非常深的残差通道注意力网络实现图像超分辨率
- 论文:Image Super-Resolution Using Very Deep Residual Channel Attention Networks
- 期刊:ECCV
- 时间:2018 年 7 月 12 日
1.1、项目背景
超分辨率图像(SR,Super-Resolution):旨在从一幅低分辨率(LR,Low-Resolution)图像中恢复一幅高分辨率(HR,High-Resolution)图像。这个任务在医学影像、卫星图像处理、安防监控以及计算机视觉的多个领域中都有重要应用。
在超分任务中,常见的三类图像如下:
低分辨率图像(Low-Resolution, LR)
:像素密度较低、尺寸较小、细节不清晰。通常是由于拍摄设备的分辨率较低、图像采样过少(采样步长较大)或原图下采样等原因造成的。例如,拍照时使用低质量摄像头,或者通过网络传输过程中图像质量压缩,都会导致图像分辨率降低。低分辨率图像常用于节省存储和带宽,但在许多应用中,低分辨率图像中的细节往往不足以支持精确的分析和理解。高分辨率图像(High-Resolution, HR)
:像素密度较高,尺寸较大,保留更多的细节。高分辨率图像通常是通过高质量设备拍摄,或者在图像生成过程中通过超分辨率技术恢复的图像。高分辨率图像具有更清晰的细节和更高的图像质量,但它们占用更多的存储空间和带宽。
超分辨率图像(Super-Resolution, SR)
:通过超分辨率模型(如 RCAN、ESRGAN、SwinIR)从 LR 预测得到的 HR 图像。目标是尽可能逼近 HR 图像,使其具有更多细节、更好的清晰度。
在超分实验中,常见的处理流程如下:
- 训练阶段:输入 LR 图像,模型学习如何恢复出 SR 图像,并与 HR 进行对比优化。
- 测试阶段:将 LR 输入训练好的超分模型,得到 SR 结果,并计算
PSNR
、SSIM
等指标评估质量。
1.2、研究现状
(1)使用 Bicubic(双三次插值,BI) 对高分辨率(HR)图像进行降质(先下采样4X,然后上采样4X),以模拟低分辨率(LR)图像。(2)将生成的 LR 图像输入不同的 SR 网络进行恢复,比较它们的性能。
近来,基于深度卷积神经网络(CNN)的方法在超分辨率(SR)方面取得了显著进展,超越了传统方法。
- 研究现状:
(组合拳)网络深度 + 残差网络
- Dong 等人率先提出了
SRCNN
,首次将三层 CNN 应用于图像超分辨率。- Kim 等人在
VDSR
和DRCN
中将网络深度增加到 20 层,相比 SRCNN 取得了显著的性能提升。- 网络深度对于许多视觉识别任务至关重要,尤其是当 He 等人提出 残差网络(ResNet) 之后。Lim 等人通过使用简化的残差块构建了非常宽的网络
EDSR
和非常深的网络MDSR
。EDSR和 MDSR 在性能上的巨大提升表明,表示的深度对于图像超分辨率至关重要。- 问题1:仅仅通过堆叠残差块来构建更深的网络很难获得更好的改进。
- 问题2:现有方法对每个通道特征同等对待(一视同仁),导致网络缺乏跨特征通道的判别式学习能力,从而限制了网络的表征能力。
1.3、论文核心
模型目标:图像超分辨率(SR)可以被视为一个图像恢复过程,其尽可能多地从低分辨率(LR)图像中恢复高频信息,得到高分辨率(HR)图像。 低分辨率图像包含大部分低频信息(如平滑区域、背景、轮廓和结构等)以及少部分高频信息(如边缘、纹理、噪声或锐利的结构变化等)
提出了一种残差通道注意力网络(RCAN,residual channel attention networks),以获得非常深的可训练网络,并同时自适应地学习更多有用的通道特征。
- (1)提出了一种
残差中的残差(RIR,residual in residual)结构
,来构建非常深的卷积神经网络(超过400层)。- (2)提出了一种
通道注意力(CA,channel attention)机制
,通过考虑通道之间的相互依赖关系来自适应地缩放通道特征.- 组成:
- 1)RIR网络:由 G 个残差组(RG,residual group)和 1 条长跳跃连接(LSC,long skip connection)组成
- 2)每个RG:由 B 个残差通道注意力块(RCAB,residual channel attention blocks)和 1 条短跳跃连接(SSC,short skip connection)组成
- 3)每个RCAB:由 1 个通道注意力(CA)和 1 个残差块(RB,residual blocks)组成,每个RB具有 1 条短跳跃链接
- 目前,基于注意力机制的一些尝试性研究工作已应用于深度神经网络,但针对低级视觉任务(如图像超分辨率)却鲜有研究工作。
- 特点:基于恒等映射的跳跃连接(LSC + SSC + 残差块中的短连接)绕过大量低频信息,使主网络专注于学习高频信息。
大量实验表明,RCAN 在精度和视觉效果方面优于最先进的方法。
1.4、网络模型(RCAN,Residual Channel Attention Networks)
RCAN由四个部分组成:
- (1)低分辨率图像输入:
I-LR
- (2)浅层网络提取特征:Conv(提取F0)
- (3)深层网络提取特征:RIR —— RG1(输入F0,输出Fg-G),RG2(输入Fg-G,输出[Fg-(G-1)]),…,RG-G(输入Fg,输出F-DF)
- 在 RIR 中,
引入了 1 条长跳跃连接(Long skip connection,LSC)
—— 由 F0 连接到 F-DF- 在每个 RG 中,
引入了 1 条短跳跃连接(Short skip connection,SSC)
—— 由 Fg-1 连接到 Fg- (4)上采样模块:Upscale module(输入F-DF,输出F-UP)
- (5)重建:Conv(输入F-UP,输出
I-SR
)
详细公式请看论文
(1)I-LR
(2)F0 = H-SF(I-LR)
(3)F-DF = H-RIR(F0)
(4)F-UP = H-UP(F-DF)
(5)I-SR = H-REC(F-UP) = H-RCAN(I-LR)
1.4.1、残差中的残差(RIR,Residual In Residual):由 G 个残差组(RG)和 1 条长跳跃连接(LSC)组成;每个RG由 B 个残差通道注意力块(RCAB)和 1 条短跳跃连接(SSC)组成;每个RCAB由 1 个通道注意力(CA)和 1 个残差块(RB)组成,每个RB具有 1 条短跳跃链接。
RIR的进阶史:
- 首先,提出了
残差组
,用于构建更深网络(但性能一般)
- 然后,在RIR中引入一条
长跳连接(LSC)
,用于稳定非常深的训练(不仅能够促进RG之间的信息流动,还能使RIR在粗层次上学习残差信息。):
- 为了进一步迈向残差学习,
在每个 RG 中堆叠 B个残差通道注意力
:
- 为了使主网络更关注更具信息量的特征,引入了一条
短跳连接(SSC)
:
结论:通过 LSC 和 SSC,在训练过程中更丰富的低频信息更容易被绕过。
1.4.2、通道注意力(CA,Channel Attention)
- 问题:当前 SR 方法对 LR 的
通道特征一视同仁
,这在实际应用中不够灵活。- 优化:引入了通道注意力(CA)机制,为每个通道特征生成不同的注意力权重
符号 × 表示逐元素乘积
- (1)
H x W x C
:设X = [x1,...,Xc-1,Xc]
为输入,其具有C
个特征图,每个特征图的大小为H x W
。- (2)全局池化函数
H-GP
:其提取得到的通道统计量(z =1 x 1 x C
)可以被视为局部描述符的集合。- (3)引入了一种门控机制
s = f[W-U * δ[W-D(z)]]
:其提取得到的聚合信息(s =1 x 1 x C
)充分捕捉了通道间的依赖关系。
- 其中:f 和 δ 分别表示 S 型门控函数和 ReLU 函数。
W-D
和W-U
是卷积层的权重集W-D
作用是按比例r
进行通道降维(1 x 1 x C/r
),经过 ReLU 激活后,W-U
作用是将低维信号按比例r
进行通道升维(1 x 1 x C)。
结论:通道注意力机制通过提取通道间的统计信息来进一步增强网络的判别能力。
1.4.3、残差通道注意力块(RCAB,Residual Channel Attention Blocks)
(见图4)将通道注意力(CA)集成到残差块(RB)中,并提出了残差通道注意力块(RCAB)。
符号 × 表示逐元素乘积
符号 + 表示逐元素累加
Fg,b-1
和Fg,b
分别是 RCAB 的输入和输出Xg,b
是通过两个堆叠的卷积层计算得到的残差部分
1.5、数据集 + 前处理
- 训练集:在 DIV2K 数据集中获取 800 幅图像
- 测试集:使用五个标准的基准数据集,包括Set5,Set14,B100,Urban100以及Manga109。
超参数
- HR - LR模拟:分别采用双三次插值(Bicubic Interpolation,BI)和模糊降采样(Blur-Downsampling,BD)对 HR 图像进行降质,并测试 RCAN 效果。
- 数据增强:随即旋转90°、180°、270°以及水平翻转
- 在每个批次中,提取 16 个大小为 48 x 48 的 LR 彩色块作为输入。
- ADAM优化器,β1 = 0.9,β2 = 0.999,ε = 10-8
- 初始化学习率 = 10-4,然后在每次反向传播的 2 x 105 次迭代后减半。
1.6、模型参数
- 在RIR结构中,RG数量设为
G = 10
,在每个RG中,将RCAB数量设为B = 20
。- 通道降维和通道升维中的卷积层的核大小为
1 x 1
,其余所有卷积层的核大小均为3 x 3
。- 滤波器
- 浅层特征提取和RIR结构中的卷积层具有
C = 64
个滤波器- 通道降维中的卷积层具有
Cr = 4
个滤波器,其中缩减比r = 16
- 对于升维模块,使用ESPCNN将低分辨率特征上采样为高分辨率特征。
1.7、研究结果
(基础)指标参数:PSNR + SSIM
在图像重建与增强任务中,评价模型性能通常采用 峰值信噪比(Peak Signal-to-Noise Ratio,PSNR)和结构相似性(Structural Similarity Index Measure,SSIM) 两个核心指标。
峰值信噪比(Peak Signal-to-Noise Ratio,PSNR)
:衡量HR(高分辨率原图)与SR(超分辨率恢复图像)之间的均方误差(MSE),是一种基于像素的传统图像质量评估指标。
- 计算公式
PSNR = 10 ⋅ log 10 ( MAX 2 MSE ) \text{PSNR} = 10 \cdot \log_{10} \left( \frac{\text{MAX}^2}{\text{MSE}} \right) PSNR=10⋅log10(MSEMAX2)
其中:
- MAX:是图像的最大像素值(通常对于8位图像,MAX=255)。
- MSE:是均方误差(Mean Squared Error),定义如下:
MSE = 1 m × n ∑ i = 1 m ∑ j = 1 n ( I HR ( i , j ) − I SR ( i , j ) ) 2 \text{MSE} = \frac{1}{m \times n} \sum_{i=1}^{m} \sum_{j=1}^{n} (I_{\text{HR}}(i,j) - I_{\text{SR}}(i,j))^2 MSE=m×n1i=1∑mj=1∑n(IHR(i,j)−ISR(i,j))2- 特点
- PSNR值越高,表示图像质量越高。
- 对极小的像素差异敏感,但对人眼感知的结构性失真不敏感。
- 适合评估轻微噪声、细粒度恢复效果,但无法准确反映感知质量。
结构相似性(Structural Similarity Index Measure,SSIM)
:从亮度、对比度和结构三个方面综合评价图像质量,是一种更符合人类视觉感知特性的指标。
- 计算公式:
SSIM ( x , y ) = ( 2 μ x μ y + C 1 ) ( 2 σ x y + C 2 ) ( μ x 2 + μ y 2 + C 1 ) ( σ x 2 + σ y 2 + C 2 ) \text{SSIM}(x, y) = \frac{(2\mu_x \mu_y + C_1)(2\sigma_{xy} + C_2)}{(\mu_x^2 + \mu_y^2 + C_1)(\sigma_x^2 + \sigma_y^2 + C_2)} SSIM(x,y)=(μx2+μy2+C1)(σx2+σy2+C2)(2μxμy+C1)(2σxy+C2)
其中:
- μx 和 μy:分别是HR和SR图像的均值
- σx2 和 σy2:分别是HR和SR图像的方差
- σxy:分别是HR和SR图像的协方差
- C1 和 C2:稳定因子,通常用于避免分母为零。
- 特点
- SSIM取值范围在0,1,数值越接近1表示图像越接近原图。
- 能有效捕捉亮度变化、对比度差异以及局部结构失真。
- 比PSNR更能反映人眼真实的主观感知质量。
(基础)降质方法:BI + BD
在 RCAN 及其相关的超分辨率任务中,降质(Degradation) 指的是如何从HR图像生成LR图像,以用于训练超分辨率模型。
备注:由于应用场景下的数据多用作HR,因此需要构建LR。若已经有了LR,则跳过该步骤。
常见的降质方法:
双三次插值(Bicubic Interpolation,BI)
:使用双三次插值直接将 HR 图像下采样到 LR。
- 优点:最常见的降质方式之一,简单快速,在许多超分辨率基准测试(如 DIV2K 数据集)中被广泛使用。
- 缺点:不会引入模糊和噪声,通常生成的 LR 图像质量较高,可能与真实降质(如相机成像)不同,导致模型泛化能力有限。
模糊-降采样(Blur-Downsampling,BD)
:先对 HR 图像进行高斯模糊(Gaussian Blur),再进行降采样(通常使用双线性插值 Bilinear Interpolation 或邻近插值 Nearest Neighbor Interpolation)。
- 优点:更接近真实降质情况,例如摄像头的光学模糊 + 传感器采样效应。
- 缺点:额外的模糊可能会增加超分辨率任务的难度,要求模型具备更强的去模糊能力。
(1)RIR 和 CA 的性能测试(对比表)
表中展示了
RIR(LSC、SSC)
和CA
在 Set5(2x) 数据集上的PSNR
值:
- 状态:×表示未启用,√表示启用
- 结论:
- (1)RIR影响:当单独或同时移除 LSC 和 SSC时,无论是否使用 CA,PSNR值都相对较低。这表明,仅仅堆叠残差块并不适用于构建非常深且强大的 SR 网络。
- (2)CA 影响:在前 4 列和后 4 列的结果中,启用 CA 比未启用 CA 的 PSNR 值更高。
(2)使用双三次插值(BI)退化模型的结果(对比图)
- 在 x4 放大倍数下的视觉对比
- 在 x8 放大倍数下的视觉对比
(3)使用模糊降采样(BD)退化模型的结果(对比图)
- 在 x3 放大倍数下的视觉对比
(4)性能与识别能力(对比表)
图像超分辨率(SR)也可作为高级视觉任务(例如,物体识别)的预处理步我们评估物体识别性能,以进一步证明 RCAN 的有效性。
- 使用相同的设置。
- 使用 ResNet-50 [11]作为评估模型,并使用 ImageNet CLS-LOC 验证数据集中的前 1000 张图像进行评估。原始裁剪的 224x224 图像用于基线,而超分辨率方法则将其缩小至 56x56。
- 使用 4种最先进的方法(例如 DRCN、FSRCNN、PSyCo和 ENet-E)对低分辨率图像进行上采样,然后计算它们的准确率。
如表4,RCAN 实现了最低的 top-1 和 top-5 错误率,进一步证明了 RCAN 具有极强的表征能力。
(5)性能与模型大小(对比图)
备注:更深的网络,将导致其权重参数的数量更多
1.8、RCAN恢复的图像具有一定的模糊性(自测经验 + 解决方案)
- github源码地址:https://github.com/yulunzhang/RCAN
在源码中 Figs 文件夹下,提供了多种不同超分辨率(SR)网络模型的对比结果。从实验结果来看,RCAN 在超分辨率重建任务中的表现最佳,但它也存在一个共同特点 —— RCAN恢复的图像具有一定的模糊性。
以下因素是设定在论文效果的基础上(如超参数等除外)
(1)主要因素:模型架构
- 模型架构限制:RCAN的残差结构可能在深层网络中丢失高频细节 —— 模型本身具有一定的局限性
- 通道注意力局限:通道注意力机制可能对某些频段特征关注不足 —— 模型本身具有一定的局限性
- 上采样方式:使用简单的插值上采样(如双三次)会引入平滑效应
(2)图像前处理:LR-HR 配对数据的质量
- 监督学习的本质:学习 LR → HR 的映射关系
- 尽管 LR 图像的高频信息很少,但 HR 图像中包含了丰富的高频细节为网络提供了恢复高频信息的依据。因此,在训练过程中,SR 网络通过成对的 LR-HR 训练数据,学习如何将 LR 图像的模糊细节映射到 HR 图像的高频细节。
- LR 图像:分辨率较低,图像中的细节和高频信息往往会丢失,呈现出模糊的效果。
- HR 图像:分辨率较高,包含了更多细节和高频信息,作为网络学习的目标。
① LR-HR 配准数据是否对齐(仅限于自提供数据对,而不是通过 HR 下采样得到 LR)
在处理医学影像 LR-HR(低分辨率-高分辨率)配准数据时,如果数据并非通过 HR 直接下采样得到 LR,而是由不同扫描条件(例如不同时间点、不同设备、不同拍摄角度等)获得的 LR-HR 对,通常会存在配准误差(像素偏差 = 0.5/1/2/3等等)。
误差通常来源于数据采集过程中的不对齐因素:
- 样本移动:由于医学影像的采集通常需要一定时间,样本可能会轻微移动,导致 LR 和 HR 影像不完全对齐。
- 相机或扫描设备抖动:在拍摄过程中,设备可能会发生轻微移动,使得 LR 和 HR 影像出现偏移。
- 检测 LR-HR 是否对齐:查看两个图层是否存在偏移(手动偏移校正)
- 处理方法:
- (1)
偏移校正
:若是平移变换则可以手动校正,否则采用图像配准(如:刚体变换/仿射变换/ B-Spline 变形)。- (2)
裁减边界区域
,以保证核心区域对齐良好。如边界存在黑边,尤其是LR-HR存在不对应的黑边。
② HR 分辨率质量(采用图像增强提升 HR 图像分辨率,可以显著提升超分效果)
在训练过程中,LR 图像和 HR 图像的质量直接影响到最终恢复效果,特别是 HR 图像的质量对于训练的影响很大。
高质量的 HR 图像
:当 训练集中 HR 图像本身的高频细节丰富、清晰时,网络能够学习到更准确的映射关系,从而在恢复过程中生成更精细的高频信息。
- 在生理影像领域,高质量的 HR 图像获取难度较高,因此通常会采用前处理方法(如去卷积、去噪、对比度增强等)来提升图像质量。
- 但要注意过度处理可能导致信息丢失或产生伪影,使得训练出的模型可能更适应 " 增强后的 HR " ,而非真实 HR,导致推理时超分效果下降。
低质量的 HR 图像
:若训练集中 HR 图像本身的高频信息不足(例如,图像拍摄质量较差或细节缺失),则网络可能会学到较为模糊的恢复策略,导致最终的超分辨率图像缺乏细节,或产生过于平滑的效果。
RCAN 实测结论:当 HR 分辨率越高时,则即使 LR 分辨率极低,得到的恢复图像仍然可以无限逼近于 HR 图像。
- 假设:在同一张图像上,通过不同的采样率(即采样步长 = 0.1,0.5,1,2,3)得到多个不同分辨率图像(image0.1,image0.5,image1,image2,image3)。
- 假设:分辨率水平由高到低:极高、较高、一般、较差
- SR任务:使用对比图像 HR = image0.1,LR = image2,将得到相对极高的恢复图像。
- SR任务:使用对比图像 HR = image0.1,LR = image3,将得到相对较高的恢复图像。
- SR任务:使用对比图像 HR = image0.5,LR = image2,将得到相对一般的恢复图像。
- SR任务:使用对比图像 HR = image0.5,LR = image3,将得到相对较差的恢复图像。
备注:若采样率受限或无法得到高分辨率HR图像,则可以通过图像前处理(去噪、频率滤波、图像增强等)提高 HR 图像的分辨率。如:将经过图像前处理的 image0.5当做 image0.1 使用。
实测方案:使用宽场滤波,显著提升图像分辨率
备注:由于分辨率提升非常明显,因此对于大部分的超分模型都通用。
③ 归一化方法(min和max使用固定值) + 切块预测(归一化的影响)
(1)图像归一化方法及使用建议 —— 使用固定归一化
(2)整图预测 + 切块预测(归一化的影响)
备注:在切块预测时,应避免对每个小块单独进行归一化操作,而是采用全图归一化后再切块。
(3)模型训练:不同损失函数对 SR 的影响
以下代码是 RCAN 源码中所提供的四种损失函数:
# "F:\AI\RCAN-master\RCAN_TrainCode\code\loss\__init__.py"
class Loss(nn.modules.loss._Loss):
def __init__(self, args, ckp):
super(Loss, self).__init__()
print('Preparing loss function:')
self.n_GPUs = args.n_GPUs
self.loss = []
self.loss_module = nn.ModuleList()
for loss in args.loss.split('+'):
weight, loss_type = loss.split('*')
if loss_type == 'MSE':
loss_function = nn.MSELoss()
elif loss_type == 'L1':
loss_function = nn.L1Loss()
elif loss_type.find('VGG') >= 0:
module = import_module('loss.vgg')
loss_function = getattr(module, 'VGG')(loss_type[3:], rgb_range=args.rgb_range)
elif loss_type.find('GAN') >= 0:
module = import_module('loss.adversarial')
loss_function = getattr(module, 'Adversarial')(args, loss_type)
损失类型 | 计算方式 | 适用范围 | 优缺点 |
---|---|---|---|
均方误差(MSE, Mean Squared Error) | 均方误差损失 | 超分辨率、去噪 | 计算稳定,但对大噪声更敏感,容易导致模糊(过度平滑) |
L1 损失(Mean Absolute Error, MAE) | 绝对误差损失 | 超分辨率、目标检测 | 保留细节,不会对大噪声过度惩罚,对异常值更鲁棒,但收敛速度较慢 |
VGG 感知损失(Perceptual Loss) | 基于VGG 网络的高层特征来计算 SR 图像与 HR 图像之间的差异。 | 超分辨率、风格迁移 | 更符合人眼感知,减少过度平滑,但计算较复杂 |
GAN 对抗损失(Adversarial Loss, GAN Loss) | 生成对抗网络损失 | 超分辨率、图像生成 | 生成高质量图像,减少过度平滑,但易产生伪影 |
为什么 MSE 对异常值(outlier)或大噪声敏感?
小误差权重较低
:(如0.1)平方后更小(0.01),这些区域可能优化不足。大误差权重较高
:(如 10)平方后更大(100),MSE 的优化目标是最小化整体损失,但由于大误差项贡献较大,模型训练时会更倾向于减少这些大误差(即过度拟合大噪声区域或异常值),而忽略小误差(即数据的整体趋势),导致训练不够鲁棒,或者网络在含噪数据上表现不佳。
- 在图像超分任务中,输出变得过度平滑,而不是恢复真实的高频细节(因为降低 MSE 最简单的方法是模糊化预测值,使其更接近整体均值)。
通常在 SR 任务中,
MSE/L1 + 感知损失 + GAN 损失
结合使用,确保高 PSNR 的同时提升视觉质量。
- Huber 损失(MSE + MAE 结合):小误差时使用 MSE 以平滑优化,大误差时使用 L1 以减少异常值影响。
- L1Loss + VGGLoss
- L1 + VGGLoss + GANLoss
class TotalLoss(nn.Module):
def __init__(self, lambda_l1=1.0, lambda_vgg=0.1, lambda_gan=0.01):
super(TotalLoss, self).__init__()
self.l1_loss = nn.L1Loss()
self.vgg_loss = PerceptualLoss(vgg_loss_weight=lambda_vgg)
self.gan_loss = nn.BCELoss() # 这里以二元交叉熵作为 GAN 损失示例
self.lambda_l1 = lambda_l1
self.lambda_vgg = lambda_vgg
self.lambda_gan = lambda_gan
def forward(self, pred, target, pred_fake, real_label):
loss_l1 = self.l1_loss(pred, target)
loss_vgg = self.vgg_loss(pred, target)
loss_gan = self.gan_loss(pred_fake, real_label) # GAN 判别器的损失
total_loss = (
self.lambda_l1 * loss_l1 +
self.lambda_vgg * loss_vgg +
self.lambda_gan * loss_gan
)
return total_loss, loss_l1, loss_vgg, loss_gan
(4)图像后处理
二、项目实战
2.1、虚拟环境
2.2、环境配置
方法一:RCAN开源代码:默认支持 PyTorch 1.2.0 —— 官方已移除(不建议)
- github源码地址:https://github.com/yulunzhang/RCAN
RCAN 基于 EDSR(PyTorch)构建,默认支持 PyTorch 1.2.0
,主要是 torch.utils.data.dataloader 限制了版本。- 在 Ubuntu 14.04/16.04 环境 (Python3.6、PyTorch0.4.0、CUDA8.0、cuDNN5.1) 和 Titan X/1080Ti/Xp GPU 上进行了测试。
备注1:PyTorch 1.2.0 版本已经被官方移除,无法通过 pip install 直接安装它(镜像源也找不到)。
备注2:PyTorch 1.2.0 只支持 CUDA 10.0,需要从官方旧版本存档下载并安装。
# EDSR 依赖库如下:
# Python 3.6
# PyTorch >= 1.0.0
# numpy
# skimage
# imageio
# matplotlib
# tqdm
# cv2 >= 3.xx (Only if you want to use video input/output)
##############################################################
"""(已放弃)尝试升级dataloader,已老实(比较麻烦)"""
# conda create -n RCAN39 -y
# conda activate RCAN39
# conda install python=3.9 -y
# 安装 PyTorch + CUDA 11.8 版本(若版本CPU,则去掉url)
# pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# pip install numpy scikit-image imageio matplotlib tqdm opencv-python opencv-python-headless
##############################################################
"""(已放弃)心态已奔溃,降版本最麻烦的就是版本对应,最棘手的就是没有对应版本。"""
# conda create -n RCAN36 -y
# conda activate RCAN36
# conda install python=3.6 -y
# [轮子下载:cpu/torch-1.0.0-cp36-cp36m-win_amd64.whl](https://editor.csdn.net/md?articleId=146098069)
# pip install numpy scikit-image imageio matplotlib tqdm -i https://pypi.tuna.tsinghua.edu.cn/simple
# 备注:conda install
##############################################################
测试结果:
(1)不支持直接执行main.py,代码中有配置绑定,只能调用命令行。
(2)通过命令行执行模型预测:
- 【路径】(RCAN36) F:\AI\RCAN\RCAN-master\RCAN-master\RCAN_TestCode\code
- 【命令】
python main.py --data_test MyImage --scale 2 --model RCAN --n_resgroups 10 --n_resblocks 20 --n_feats 64 --pre_train ../model/RCAN_BIX2.pt --test_only --save_results --chop --save 'RCAN' --testpath ../LR/LRBI --testset Set5
方法二:basicSR开源代码:是一个集成了多种超分模型的工具 —— 建议
BasicSR (Basic Super Restoration)
:是一个基于 PyTorch 的图像 / 视频超分复原增强工具箱(比如超分辨率, 去噪, 去模糊, 去 JPEG 压缩噪声等),旨在简化和加速任务的开发。它提供了多种超分辨率模型,包括经典和最新的深度学习模型,并且具有高效的训练和推理接口。BasicSR 提供了许多用于超分辨率、图像恢复等任务的工具,适用于研究人员和开发人员。
- github源码地址:https://github.com/XPixelGroup/BasicSR
备注:点击对应链接,查看相应内容。
conda create -n BasicSR37 -y
conda activate BasicSR37
conda install python=3.7 -y
###########################################################################
# (1)安装必需文件
cd /d F:\AI\RCAN\BasicSR-master
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
###########################################################################
"""requirements.txt(以下是源码中修改后的内容)"""
addict
future
lmdb
numpy
opencv-python
Pillow
pyyaml
recommonmark
requests
scikit-image
scipy
sphinx
sphinx_intl
sphinx_markdown_tables
sphinx_rtd_theme
# tb-nightly 安装失败
# torch>=1.7 # 单独安装GPU版本
# torchvision # 单独安装GPU版本
tqdm
yapf
###########################################################################
# (2)以下是模型训练时需要额外安装的库,不包含在requirements.txt里,所以单独列出。
pip install six tensorboard -i https://pypi.tuna.tsinghua.edu.cn/simple
###########################################################################
# (3)安装GPU版本的torch,CPU版本的蜗牛速度让人头大。(注意CUDA版本)
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu113
# 根据电脑配置和CUDA版本安装对应的Torch(亲测有效,python=3.9仍然可以安装成功)
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
import torch
print(torch.cuda.is_available()) # 判断GPU是否可用
2.3、basicSR:RCAN 网络模型 —— 用于理解模型架构,可跳过
在 BasicSR 代码库中:
basicsr/archs/
目录(rcan_arch.py):主要存放各种神经网络的架构(architecture)basicsr/models/
目录(sr_model.py):存放的是模型的整体逻辑,包括如何加载、训练和推理。
import torch
from torch import nn as nn
from basicsr.utils.registry import ARCH_REGISTRY
from .arch_util import Upsample, make_layer
class ChannelAttention(nn.Module):
"""Channel attention used in RCAN. ———— 该模块用于计算通道注意力权重,增强特征表达能力。
参数:
num_feat (int): 中间特征的通道数(即输入通道数)。
squeeze_factor (int): 用于通道压缩的因子(默认为 16)。
"""
def __init__(self, num_feat, squeeze_factor=16):
super(ChannelAttention, self).__init__()
self.attention = nn.Sequential(
nn.AdaptiveAvgPool2d(1), # 全局平均池化,将每个通道的特征图转换为一个标量
nn.Conv2d(num_feat, num_feat // squeeze_factor, 1, padding=0), # 1x1 卷积: 降维(通道数减少)
nn.ReLU(inplace=True), # ReLU 激活函数
nn.Conv2d(num_feat // squeeze_factor, num_feat, 1, padding=0), # 1x1 卷积: 升维(通道数恢复)
nn.Sigmoid() # 归一化权重(0~1)
)
def forward(self, x):
y = self.attention(x) # 计算通道权重
return x * y # 原特征乘上注意力权重 ———— 增强关键通道特征,抑制无用信息。
class RCAB(nn.Module):
"""Residual Channel Attention Block (RCAB) used in RCAN. ———— 该模块在标准的残差块中加入通道注意力机制,以增强特征提取能力。
参数:
num_feat (int): 中间特征的通道数。
squeeze_factor (int): 通道压缩因子(默认为 16)。
res_scale (float): 残差缩放比例,控制残差贡献的大小(默认为 1)。
"""
def __init__(self, num_feat, squeeze_factor=16, res_scale=1):
super(RCAB, self).__init__()
self.res_scale = res_scale # 记录残差缩放比例
self.rcab = nn.Sequential(
nn.Conv2d(num_feat, num_feat, 3, 1, 1), # 3x3 卷积,保持通道数不变
nn.ReLU(True), # ReLU 激活函数
nn.Conv2d(num_feat, num_feat, 3, 1, 1), # 3x3 卷积
ChannelAttention(num_feat, squeeze_factor) # 通道注意力模块
)
def forward(self, x):
res = self.rcab(x) * self.res_scale # 计算残差
return res + x # 残差连接 ———— 让信息直接传递,避免梯度消失问题。
class ResidualGroup(nn.Module):
"""Residual Group of RCAB. ———— 该模块由多个 RCAB 组成,并通过一个额外的 3x3 卷积进行特征融合。
参数:
num_feat (int): 中间特征的通道数。
num_block (int): 该残差组中的 RCAB 数量。
squeeze_factor (int): 通道压缩因子(默认为 16)。
res_scale (float): 残差缩放比例(默认为 1)。
"""
def __init__(self, num_feat, num_block, squeeze_factor=16, res_scale=1):
super(ResidualGroup, self).__init__()
self.residual_group = make_layer(RCAB,
num_block,
num_feat=num_feat,
squeeze_factor=squeeze_factor,
res_scale=res_scale) # 多个 RCAB 组成一个 Residual Group。
self.conv = nn.Conv2d(num_feat, num_feat, 3, 1, 1) # 额外增加一个 3x3 卷积(用于融合 RCAB 提取的特征)
def forward(self, x):
res = self.conv(self.residual_group(x)) # 计算残差
return res + x # 残差连接
#####################################################################
# RCAN 采用深度残差学习 + 通道注意力机制来增强超分辨率性能。
# 主要特点:
# RCAB(残差通道注意力块):两层卷积 + 通道注意力。
# ResidualGroup(残差组):多个 RCAB 组成一个较大的单元,增强学习能力。
# 全局残差连接:加速训练,提高稳定性。
# Upsample 上采样:将低分辨率图像变成高分辨率。
#####################################################################
@ARCH_REGISTRY.register()
class RCAN(nn.Module):
"""Residual Channel Attention Networks. ———— 该网络用于超分辨率任务,在多个残差组中引入通道注意力机制。
论文: 《Image Super-Resolution Using Very Deep Residual Channel Attention Networks》
源码: https://github.com/yulunzhang/RCAN
参数:
num_in_ch (int): 输入通道数(例如 RGB 图像为 3)。
num_out_ch (int): 输出通道数(通常与输入通道数相同)。
num_feat (int): 中间特征的通道数(默认为 64)。
num_group (int): 残差组(Residual Group)的数量(默认为 10)。
num_block (int): 每个残差组中的 RCAB 数量(默认为 16)。
squeeze_factor (int): 通道压缩因子(默认为 16)。
upscale (int): 超分辨率放大倍数(支持 2^n 和 3,默认为 4)。
res_scale (float): 残差缩放比例(默认为 1)。
img_range (float): 图像数值范围(默认为 255)。
rgb_mean (tuple[float]): 计算均值标准化时的 RGB 均值,默认使用 DIV2K 数据集的均值。
"""
def __init__(self,
num_in_ch,
num_out_ch,
num_feat=64,
num_group=10,
num_block=16,
squeeze_factor=16,
upscale=4,
res_scale=1,
img_range=255.,
rgb_mean=(0.4488, 0.4371, 0.4040)):
super(RCAN, self).__init__()
self.img_range = img_range # 输入图像数值范围
self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) # 计算图像均值
##############################################################################
self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) # 输入卷积层
self.body = make_layer(ResidualGroup,
num_group,
num_feat=num_feat,
num_block=num_block,
squeeze_factor=squeeze_factor,
res_scale=res_scale) # 主体网络,由多个残差组组成
self.conv_after_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1) # 额外的 3x3 卷积层(用于特征融合)
self.upsample = Upsample(upscale, num_feat) # 上采样模块(用于提高分辨率)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) # 输出卷积层
##############################################################################
def forward(self, x):
self.mean = self.mean.type_as(x)
# 归一化输入
x = (x - self.mean) * self.img_range
##############################################################################
# 通过网络计算超分辨率结果
x = self.conv_first(x) # 初始卷积
res = self.conv_after_body(self.body(x)) # 主体网络
res += x # 全局残差连接
x = self.conv_last(self.upsample(res)) # 上采样 + 输出
##############################################################################
# 反归一化输出
x = x / self.img_range + self.mean
return x
2.4、basicSR:模型训练 + 模型测试
- 数据集准备:HR 和 LR 数据集中的每个图像名称都必须一一对应。
- 备注:RCAN源码路径下有对应的数据,且提供数据集下载。
- 运行命令(训练命令或测试命令)
- 所有命令都在 BasicSR-master 根目录下运行,若是在 BasicSR-master / basicsr / 目录下运行,将提示找不到 basicsr 模块。其中:basicsr 是源码库,但 BasicSR-master 中包含更多的配置。
- 在 BasicSR-master / options 文件夹下,修改配置文件(.yml),详细请参考《配置说明(训练+测试)》
以下是配置文件中的常用参数,其余参数几乎不变:
结果保存路径(name:BasicSR-master / experiments / name)
超分辨率模型(model_type:SRModel)
设备类型(num_gpu:0 for cpu mode)
训练集路径(train - dataroot_gt + dataroot_lq)
测试集路径(val - dataroot_gt + dataroot_lq)
网络模型 - 参数配置(network_g)
是否启用迁移学习(path - pretrain_network_g) —— 若启用,则指定pt文件路径
保存模型的频率(logger - save_checkpoint_freq) —— 表示每迭代N次保存一个模型
模型训练 - 参数配置(train):total_iter(总迭代次数)、loss type(MSE / L1)
模型验证 - 参数配置(val)
备注:文件路径均不可以添加引号(单或双)
"""
模型训练:python -m basicsr.train -opt options/train/RCAN/train_RCAN_x2.yml
模型测试:python -m basicsr.test -opt options/test/RCAN/test_RCAN.yml
"""
# 以下是一些调试代码(分为三个部分):
##########################################################################
# train.py
import warnings
warnings.filterwarnings("ignore", category=RuntimeWarning)
##########################################################################
# base_model.py - load_network()
if param_key in load_net:
load_net = load_net[param_key] # 原来的逻辑
else:
load_net = load_net # 直接加载整个字典
##########################################################################
# train.py
print(f"epochs = {epoch}/{total_epochs}, iters = {current_iter}/{total_iters}")
(1)配置文件(参数详解):train_RCAN_x2.yml
不可以出现中文
# general settings: Chinese cannot appear
name: 201_RCANx2_scratch_DIV2K_rand0 # name your experiment
model_type: SRModel
scale: 2 # 2x SR
num_gpu: 1 # set num_gpu: 0 for cpu mode
manual_seed: 10 # fix random seed
# dataset and data loader settings
datasets:
train: # training datasets
name: DIV2K # name of the training dataset
type: PairedImageDataset # options: [SingleImageDataset, PairedImageDataset]
dataroot_gt: ./datasets/val_set5/Set5 # GT dataset
dataroot_lq: ./datasets/val_set5/Set5_bicLRx4 # LR dataset
filename_tmpl: '{}' # file name template
io_backend: # options: [lmdb, disk]
type: disk
# (for lmdb)
# type: lmdb
gt_size: 96 # GT size
use_hflip: true # data augmentation policy
use_rot: true # data augmentation policy
# data loader
num_worker_per_gpu: 6 # number of workers per GPU
batch_size_per_gpu: 16 # total batch size
dataset_enlarge_ratio: 100 # enlarge the size of each image in the dataset.
prefetch_mode: ~
val: # validation datasets
name: Set5 # name of the validation dataset
type: PairedImageDataset # options: [SingleImageDataset, PairedImageDataset]
dataroot_gt: ./datasets/val_set5/Set5 # GT dataset
dataroot_lq: ./datasets/val_set5/Set5_bicLRx4 # LR dataset
io_backend: # options: [lmdb, disk]
type: disk
# network structures
network_g:
type: RCAN # options: [RCAN, RCAN_x2, RCAN_x4]
num_in_ch: 3 # number of input channels —— gray_image=1, RGB_image=3
num_out_ch: 3 # number of output channels —— gray_image=1, RGB_image=3
num_feat: 64 # number of feature maps
num_group: 10 # number of residual groups
num_block: 20 # number of residual blocks
squeeze_factor: 16 # squeeze factor of channel
upscale: 2 # upsample ratio —— Channel Attention
res_scale: 1 # residual scaling
img_range: 255. # range of input image values
rgb_mean: [0.4488, 0.4371, 0.4040] # image mean
# path
path:
pretrain_network_g: ~ # path to pre-trained model
strict_load_g: true # whether to strictly load the pre-trained model
resume_state: ~ # path to the checkpoint to resume training
# training settings
train: # options: [SRTrain, SRTest]
ema_decay: 0.999 # decay factor for exponential moving average
optim_g: # options: [Adam, SGD]
type: Adam # options: [Adam, SGD]
lr: !!float 1e-4 # initial learning rate for Adam
weight_decay: 0 # weight decay
betas: [0.9, 0.99] # betas for Adam
scheduler:
type: MultiStepLR # options: [MultiStepLR, CosineAnnealingLR]
milestones: [200000] # list of epoch indices
gamma: 0.5 # decay factor for MultiStepLR
total_iter: 300000 # total number of training iterations
warmup_iter: -1 # number of iterations for warm up
# losses
pixel_opt:
type: L1Loss # options: [L1Loss, MSELoss]
loss_weight: 1.0 # loss weight for L1 loss
reduction: mean # reduction method for L1 loss
# validation settings
val: # options: [SRTest, SRTestVid4]
val_freq: !!float 5e3 # validation frequency
save_img: false # if save_img=True, the visualized results will be saved.
metrics:
psnr: # metric name, can be arbitrary
type: calculate_psnr # options: [calculate_psnr, calculate_ssim]
crop_border: 2 # crop border when evaluation
test_y_channel: false # test on Y channel of YCbCr
# logging settings
logger: # options: [Logger, TensorboardLogger]
print_freq: 100 # print frequency
save_checkpoint_freq: !!float 5e3 # frequency of saving checkpoints
use_tb_logger: true # whether to use tensorboard logger
wandb: # options: [None, WandbLogger]
project: ~ # project name of wandb
resume_id: ~ # resume id of wandb
# dist training settings
dist_params: # options: [None, distributed, ddp_cpu]
backend: nccl # options: [nccl, gloo]
port: 29500 # port can also be set in the command line
(2)配置文件(参数详解):test_RCAN.yml
不可以出现中文
# general settings: Chinese cannot appear
name: RCAN_BIX4-official # model name
model_type: SRModel # model type
scale: 2 # scale factor, only from {2, 3, 4}
num_gpu: 1 # number of gpus to use, 0 for CPU
manual_seed: 10 # manually set random seed
datasets:
test_1: # the 1st test dataset
name: val_set5 # dataset name
type: SingleImageDataset # options: [SingleImageDataset, PairedImageDataset]
dataroot_lq: ./datasets/val_set5/Set5_bicLRx4
io_backend: # options: [lmdb, disk]
type: disk
# test_2: # the 2nd test dataset
# name: val_set14
# type: PairedImageDataset
# dataroot_gt: ./datasets/val_set14/Set14
# dataroot_lq: ./datasets/val_set14/Set14_bicLRx4
# io_backend:
# type: disk
# network structures
network_g:
type: RCAN # options: [RCAN, RCAN_x2, RCAN_x4]
num_in_ch: 3 # number of input channels —— gray_image=1, RGB_image=3
num_out_ch: 3 # number of output channels —— gray_image=1, RGB_image=3
num_feat: 64 # number of feature maps
num_group: 10 # number of residual groups
num_block: 20 # number of residual blocks
squeeze_factor: 16 # squeeze factor of channel
upscale: 2 # upsample ratio —— Channel Attention
res_scale: 1 # residual scaling
img_range: 255. # range of input image values
rgb_mean: [0.4488, 0.4371, 0.4040] # image mean
# validation settings
val:
val_freq: !!float 5e3 # validation frequency
save_img: true # if save_img=True, the visualized results will be saved.
suffix: "_suffix" # suffix of the saved images
# If you need to calculate PSNR and SSIM, you need to give both LR and HR paths
# metrics:
# psnr: # metric name, can be arbitrary
# type: calculate_psnr # options: [calculate_psnr, calculate_ssim]
# crop_border: 0 # crop border when evaluate
# test_y_channel: false # test on Y channel of YCbCr
# path
path:
pretrain_network_g: F:\AI\RCAN\BasicSR-master\experiments\201_RCANx2_scratch_DIV2K_rand0\models\net_g_50.pth # path to pre-trained model
strict_load_g: true # whether to strictly load the pre-trained model
resume_state: ~ # path to the checkpoint to resume training
2.5、basicSR:结果展示
- 模型训练:
- 预训练模型:保存于
BasicSR-master / experiments
,文件夹名称由 yml 的 name 参数决定。
- 测试结果:保存于
BasicSR-master / results
,文件夹名称由 yml 的 name 参数决定。
2.6、basicSR:剪枝 —— 加速推理
剪枝目标:
(1)减少参数量(冗余权重)
(2)减少计算量(FLOPs 下降,加速推理)
(3)尽可能保留重建性能(避免 PSNR 下降过多)
剪枝方法 | 影响参数 | 影响内容 |
---|---|---|
减少残差块(B) | num_block | 降低局部特征增强能力,减少 3×3 卷积次数 |
减少残差组(G) | num_group | 降低网络深度,减少 3×3 卷积次数 |
减少通道数 | num_feat | 降低所有 3×3 卷积计算量,减少参数 |
- 推荐顺序:先剪 B(RCAB),再剪 G(RG),最后再做通道数剪枝。
- (1)G(RG 组数)比 B(RCAB 数量)影响更大,剪 B 时影响较小,而剪 G 需要谨慎。
- (2)如果任务对 PSNR 要求高(如医学影像),建议少剪 G,多剪 B,保持全局一致性!
- 经验总结:
- (1)
剪 B(RCAB 块数)
:影响局部特征增强,适度减少不会显著影响 PSNR,但过度减少会影响边缘细节恢复。如:16 -> 8 -> 4
- (2)
剪 G(RG 组数)
:影响全局建模能力,减少过多会导致 PSNR 下降明显,影响整体色彩、纹理一致性。如:20 -> 10
- (3)
剪通道数(num_feat)
:计算量下降最多,但会影响注意力机制。如:64 -> 48 -> 32
- 备注:建议所有剪枝都进行 Finetune(微调),以恢复精度! —— (1)剪 B 可以直接预测,但 Finetune 会更稳定;(2)剪 G 影响较大,建议 Finetune;(3)剪通道数必须重新训练或微调,否则直接加载原模型会出错(维度不匹配)。
- (1)
network_g:
type: RCAN # options: [RCAN, RCAN_x2, RCAN_x4]
num_in_ch: 3 # number of input channels —— gray_image=1, RGB_image=3
num_out_ch: 3 # number of output channels —— gray_image=1, RGB_image=3
num_feat: 64 # number of feature maps
num_group: 10 # number of residual groups
num_block: 20 # number of residual blocks
(1)影响和经验:残差组数(num_group) + 残差块数(num_block)
在 RCAN 等超分网络中,
num_group
和num_block
直接影响网络的深度和建模能力,调整这些参数可用于剪枝,以减少计算量(FLOPs)并加速推理。
- (1) num_group(残差组数, RG)
- 作用:决定网络的全局深度,控制 残差组(Residual Group, RG) 的数量,每个 RG 内部包含多个 RCAB(Residual Channel Attention Block)。
- 影响:
- 减少 RG(num_group 变小) → 降低建模能力,影响 全局特征一致性,可能导致超分输出的整体色彩、纹理发生偏差。
- 增加 RG(num_group 变大) → 提高网络深度,增强全局信息建模能力,但增加计算量。
- (2) num_block(残差块数, RCAB 数量)
- 作用:决定每个 RG 内部的深度,即 每个残差组中的 RCAB(Residual Channel Attention Block)数量。
- 影响:
- 减少 RCAB(num_block 变小) → 局部特征增强能力下降,可能影响局部细节恢复(如边缘、纹理)。
- 增加 RCAB(num_block 变大) → 提高细节恢复能力,但计算量和显存需求增加。
剪枝策略 | FLOPs降低 | 推理加速 | PSNR 变化 | 影响 |
---|---|---|---|---|
B=20 → B=10 | 40% | 1.2× | ≈ -0.05 dB | 优先剪 B,影响较小 |
G=20 → G=10 | 50% | 1.5× | ≈ -0.1 dB | 适当减少 G,避免过度 |
G=20, B=20 → G=10, B=10 | 75% | 2× | ≈ -0.15 dB | dB 影响较大,但加速明显 |
(2)影响和经验:特征通道数(num_feat)
num_feat
是超分网络(如 RCAN)剪枝中的核心参数,影响多个关键计算模块。
- (1) 影响每个卷积层的输出通道数
- 作用:num_feat 决定了所有 3×3 卷积的输出通道数,影响整个网络的特征表达能力。
- 影响:降低 num_feat 会减少计算量(FLOPs 降低),但会削弱模型的特征提取能力,导致细节恢复能力下降。
- (2) 影响 RCAB 块(Residual Channel Attention Block, 残差通道注意力块)
- 作用:RCAB 负责局部特征增强,内部的 SE(Squeeze-and-Excitation)注意力模块需要用 num_feat 进行 1×1 卷积变换。
- 影响:
- num_feat 变小 → SE 模块的通道注意力能力下降 → 全局亮度、色彩可能受到影响
- num_feat 变大 → 网络容量增加,但计算量和显存需求也会上升
- (3) 影响全局残差连接过程(Long-Range Residual Learning)
- 作用:RCAN 使用 全局残差连接(GRC),在输入和输出之间建立直接的通道映射(即 F(x) + x)。
- 影响:
- 降低 num_feat → 影响全局信息传递,可能导致超分一致性变差
- 保持适中 num_feat → 平衡计算量和恢复质量
num_feat 剪枝方案 | FLOPs 下降 | 推理加速 | PSNR 影响 | 建议 |
---|---|---|---|---|
num_feat 剪枝 25% | 25% | 1.2× | 轻微下降 | 适用于轻量级优化 |
num_feat 剪枝 50% | 50% | 1.5× | 明显下降 | 适用于实时优化 |
num_feat 剪枝 75% | 75% | 2× | PSNR 下降严重 | 不推荐,影响较大 |
2.7、basicSR:通量 —— MB/s
模型预测有两种模式:
- 整图预测(默认):支持不同尺寸(1920 x 1920)、(1920 x 500)
- 切块预测(自定义):通常是将整张图像裁剪为多个相同尺寸的小块(128 x 128),然后将预测结果按切块顺序拼接为整张图像。
(1)basicSR:参数测试
通量测试:用于测试模型本省的性能。因此,只计算通量,不考虑效果。
- 剪枝参数:5 + 3 + 32 分别表示 G + B + 通道数
- 通量参数:剪枝参数 + 切块尺寸 + 切块数量…等
通量测试(以下是自定义的切块预测)
剪枝参数
切块尺寸
切块数量
(2)basicSR:TensorRT加速
📌 影响 TensorRT 加速的关键因素
- 模型结构
- 深度较浅的模型(如 ResNet-18)加速较少,深度较深的(如 RCAN)加速更明显。
- 3×3 卷积较多的模型(如 U-Net)加速比更大。
- GPU 硬件
- 高端 GPU(如 A100, RTX 4090) 支持更快的 Tensor Core 计算,INT8 提速更明显。
- 低端 GPU(如 GTX 1650) FP16 加速有限,甚至可能比 FP32 还慢。
- 输入分辨率
- 大分辨率图像(如 4K 超分)计算量大,TensorRT 优化效果更明显。
- 小尺寸输入(如 64×64)可能无法充分利用 Tensor Core,影响加速比。
- Batch Size
- 批量推理(batch > 1),TensorRT 并行计算优化效果更显著。
- 单张推理(batch = 1),加速比可能不如预期。
假设 Python 版推理通量是
1.01 MB/s
,那么:
- 直接使用 TensorRT C++(FP32) → 可达
3× ~ 5× → 3~5 MB/s
- 启用 FP16 计算 → 可达
4.5× ~ 10× → 4.5~10 MB/s
- 启用 INT8 计算(但有精度损失) → 可达
6× ~ 15× → 6~15 MB/s
✌️✌️✌️GoodLuck!✍️✍️✍️