论文笔记 |【CVPR2021】Uformer: A General U-Shaped Transformer for Image Restoration

本文介绍了Uformer,一种结合U-Net和Transformer结构的图像恢复模型。Uformer通过LeWinTransformer块降低了Transformer的计算复杂度,同时利用多尺度恢复调制器增强细节恢复。实验表明,该模型在图像去模糊和去噪等任务中表现出色,提升了图像修复的质量,并达到SOTA水平。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

论文笔记 |【CVPR2021】Uformer: A General U-Shaped Transformer for Image Restoration


在这里插入图片描述

1 Motivation

  • CNN无法捕捉长距离的像素依赖
  • Transformer计算复杂度高达二次方、获取局部依赖关系受限

解决方法: 提出Uformer

  • 参考了U-Net和transformer的结构
  • 将U-Net中的卷积层用transformer代替
  • 编码器-解码器结构和skip-connection保持不变。

核心设计:

  • LeWin Transformer块:
    非重叠窗口的自注意力—降低计算复杂度+捕获更多全局依赖关系
    前向传播的网络,两个全连接层之间引入深度卷积层—更好捕获局部上下文
  • 可学习、多尺度、恢复调制器:处理图像退化,恢复更多细节。可以灵活用于不同框架下的不同图像恢复任务(general)

在这里插入图片描述
贡献: Uformer+恢复调制器提升修复质量+sota
在这里插入图片描述

2 Method

在这里插入图片描述

2.1 整体结构

编码器:

  • 输入图像I,input projection提取底层特征
  • 3×3卷积
  • LeakyReLU,得到特征图X。
    在这里插入图片描述
  • 通过K个模块,每个模块包含一个LeWin transformer块和一个下采样层。

LeWin transformer块

  • 下采样层:把transformer输出的一维向量reshape为二维特征,再进行正常下采样
  • 步长为2的4×4卷积核:通道数加倍;分辨率HW都减半
  • 最后一层:多个LeWin块:捕获远程信息,如果卷积核尺寸和特征图大小相同,甚至捕获到全局信息。
    在这里插入图片描述

解码器:

  • K个模块,每个模块对应一个上采样层和一个LeWin transformer块。
  • 上采样层:步长为2的转置卷积,卷积核尺寸为2×2:分辨率加倍,通道数减半
  • LeWin transformer块传入:经过上采样层处理的特征图+编码器对应层的特征图(跳跃连接)+Modulartors传入
  • 最后:输出的一维向量,reshape为二维,使用3×3的卷积层获得与输入图像尺寸相同的残差图R,与输入图像相加,输出

文中K=4;损失函数:
在这里插入图片描述
L1和L2损失会造成图像过于平滑的问题,缺乏感官上的照片真实感;
效果的猜测:常数的存在,避免了梯度消失(接近0 的点);
开方避免梯度爆炸?

2.2 LeWin transformer Block

trnasformer存在问题:

  • 计算复杂度(可以计算所有token之间的全局自注意力,计算复杂度是token数量的二次方,不适用于高分辨率任务)

  • trnasformer在获取局部信息存在局限性
    在这里插入图片描述
    在这里插入图片描述
    LN:Layer Norm

  • W-MSA: non-overlapping Window-based Multi-head Self-Attention
    · 二维图像X,通道数为C,HW
    · 将X切割为多个尺寸为M×M的窗,窗与窗之间不重叠
    · 将每个窗的图像拉伸为一维,这样每个窗内的图像区域就成了M平方的向量
    · 每个窗内的这些向量用自注意力机制进行处理
    在这里插入图片描述
    · 计算第k个头的自注意力。
    · 将k个头的输出拼接,线性投影,最终W-MSA输出。
    · 完整注意力机制:(B,相对位置偏置)
    在这里插入图片描述
    效果:降低计算复杂度
    在这里插入图片描述
    分窗口,分层,低分辨率上的自注意力机制可以获得很大的感受野,能学习远程信息

  • LeFF: Locally-enhanced Feed-Forward Network 局部增强前馈网络
    在这里插入图片描述
    · transformer中的前馈网络在捕捉局部上下文信息具有局限性,因此加入深度卷积
    · 1×1卷积核,处理每个token增加通道数量
    · 将token重新变为二维特征图
    · 3×3的卷积捕获局部信息
    · Img2Tokens将二维特征图拉回一维token
    · 1× 1卷积降低通道数量使其和输入前的通道数量相同,叠加。
    · 激活函数:GELU

2.3 Multi-Scale Restoration Modulator 多尺度恢复调制器

在这里插入图片描述
在这里插入图片描述

  • 调制器:M ×M ×C 的可学习张量,在自注意模块之前添加到所有非重叠窗口。
  • 引入了边际额外参数和计算成本,但是很小。有助于以很少的计算成本恢复细节。
  • 有效性:图像去模糊和图像去噪。

可视化:
在这里插入图片描述
可能的解释:在解码器的每个阶段添加调制器可以灵活调整特征图,从而提高恢复细节的性能。

3 Experiment

在这里插入图片描述
· 动量项为(0:9; 0:999)
· 权重衰减为0.02
· 余弦衰减策略
· 用PSNR和SSIM结构相似性作为评价指标,除了去雨任务,其他任务都是在RGB色彩空间进行指标计算。去雨在YCbCr色彩空间的Y通道计算指标
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

4 消融实验

在这里插入图片描述
在这里插入图片描述 Transformer与CNN的对比: 将LeWin Transformer块中的用基于CNN的ResBlock代替
层次结构的多尺度与单一尺度的对比: 基于ViT的单一尺度模型用在降噪,包含两个卷积层的头用于特征提取,一个包含两个卷积层的尾用于输出,头和尾之间设置了12个transformer块
在这里插入图片描述
Locally-enhanced在各个位置的效果: 不用局部增强/放在自注意力部分/都放
结论: 放在前馈网络处效果最好
GMACs:乘法累加操作次数
在这里插入图片描述
多尺度恢复调制器:
在这里插入图片描述

5 Conclusion

提出Uformer

  • LeWin Transformer块:
    · 非重叠窗口的自注意力—降低计算复杂度+捕获更多全局依赖关系
    · 前向传播的网络,两个全连接层之间引入深度卷积层—更好捕获局部上下文
    · 可学习、多尺度、恢复调制器Modulator:处理图像退化,恢复更多细节。

6 Addition

  1. 窗口移位的作用 消融实验 0.01??
    在这里插入图片描述
  2. 跳跃连接的变体:设计了三种不同的跳过连接机制

在这里插入图片描述
在这里插入图片描述

  • Concat Skip:UNet广泛运用。将第 l 阶段的编码特征 El 和(K-l+1) 解码器阶段的特征 DK-l+1 逐通道拼接。K 是编码器/解码器阶段的数量。拼接特征馈送到解码器阶段的第一个 LeWin 块的 W-MSA 模块
  • Cross Skip:参考Transformer中的解码结构。在第一个LeWin模块中添加一个额外的注意力模块,它用于从解码特征D中寻找自相似像素信息,而第二个注意力则从解码特征中提取key与values,从第一个注意力模块输出中提取queries。
  • Concat Cross Skip:组合前两种。编码特征提取v k,解码特征提取queries。
    在这里插入图片描述
    效果: Concat最好。(正常浮动?原本在正文,后改到附录)GMACs:计算量

其他实验部分:
莫尔条纹的移除 效果不错

在这里插入图片描述

import time import re import os import matplotlib.pyplot as plt from lr_scheduler import LrScheduler from multiprocessing import Process import torch import torch.optim as optim from torch.utils.data import DataLoader from dataloader.data_UFormer import TrainingDataset # 你需要自定义你的数据集类 from model.UFormer import UFormer import yaml from criterions.criterionForUFormer import SISNRPlusCosSim import visdom import gc from torch.cuda.amp import autocast, GradScaler def save_kwargs_to_yaml(filename, **kwargs): with open(filename, 'w') as file: yaml.dump(kwargs, file) def load_kwargs_from_yaml(filename): with open(filename, 'r') as file: kwargs = yaml.safe_load(file) return kwargs def save_mean_gradient_plot(param_names, mean_grad_values, epoch, png_dir): """ 保存当前epoch的平均梯度图。 参数: param_names (list): 参数名称列表。 mean_grad_values (list): 对应参数的平均梯度值列表。 epoch (int): 当前的epoch数。 png_dir (str): 保存PNG文件的目录。 """ # 新建图表 plt.figure(figsize=(5120 / 300, 2880 / 300), dpi=300) # 绘制柱状图 plt.bar(param_names, mean_grad_values) # 设置图表标签 plt.xlabel('Parameter Name') plt.ylabel('Mean Gradient Magnitude') plt.xticks(rotation=90) # 旋转x轴标签以方便查看 # 设置图表标题 plt.title(f'Epoch {epoch + 1} Mean Gradient Magnitude per Parameter') # 调整图表布局 plt.tight_layout() # 构建保存路径 save_path = os.path.join(png_dir, f'gradient_plot_epoch_{epoch + 1}.png') # 保存图表为PNG文件 plt.savefig(save_path) # 清除当前图表 plt.close() def worker(): kwargs = load_kwargs_from_yaml("../config/kwargs_UFormer.yaml") N1 = kwargs['encoder_dimension'] L = kwargs['encoder_win_length'] transformer_seg_len = kwargs['transformer_seg_len'] H = kwargs['multihead'] clip_value = kwargs['clip_value'] lr = kwargs['learning_rate'] wd = kwargs['weight_decay'] wre = kwargs['warm_up_epochs'] num_epochs = kwargs['epochs'] seed = kwargs['seed'] bs = kwargs['batchsize'] de =
03-12
<think>好的,我现在需要帮助用户生成一个使用PyTorch训练UFormer模型的脚本,包含数据加载器、学习率调度器和梯度可视化。首先,我得回忆一下PyTorch的基本结构,确保模型定义、数据加载、训练循环这些部分都包含进去。用户提到的UFormer模型可能是一个自定义的模型,可能需要先定义它的结构,但用户没有提供细节,所以可能需要假设一个简单的Transformer结构或者使用现有的实现。不过这里可能需要用户自己补充,所以我会先留出模型定义的占位符。 接下来是数据加载器。PyTorch的DataLoader通常需要数据集类,这里可能需要用户根据自己的数据实现Dataset类。我会在代码中留出数据加载的部分,比如使用示例数据或者加载用户自己的数据集。然后是学习率调度器,PyTorch提供了如StepLR、ReduceLROnPlateau等,这里可以选择一个常用的,比如CosineAnnealingLR或者自定义的调度器,同时结合优化器来设置。 梯度可视化部分,用户可能希望在训练过程中实时监控梯度的情况,这可以通过注册钩子函数来实现。比如在前向传播时,记录各层的梯度,然后使用matplotlib进行可视化。需要注意的是,钩子函数可能会影响性能,但用于调试目的还是可行的。 然后,我需要将这些部分整合到一个脚本中,确保各部分正确连接。比如,在训练循环中,每个batch处理后执行梯度计算,并定期调用可视化函数。同时,学习率调度器在每个epoch结束后调整学习率。 还要考虑代码的可读性和注释,方便用户理解各部分的功能。可能还需要添加一些参数,如batch_size、epochs等,让用户可以根据需要调整。另外,梯度可视化可能需要保存图像或显示在窗口中,这里选择在训练过程中每隔一定步骤显示一次。 最后,生成相关问题,帮助用户深入理解相关技术点,比如模型结构、数据加载器的定制、学习率调度器的选择以及梯度消失/爆炸的问题。这些问题需要与用户提供的示例中的问题类似,覆盖脚本中的关键部分。</think>以下是一个使用PyTorch训练UFormer模型的完整脚本示例,包含数据加载器、学习率调度器和梯度可视化功能: ```python import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader, Dataset from torch.optim.lr_scheduler import CosineAnnealingLR import matplotlib.pyplot as plt # 假设的UFormer模型结构(需根据实际论文调整) class UFormer(nn.Module): def __init__(self, input_dim=64, num_heads=8, num_layers=6): super().__init__() self.layers = nn.ModuleList([ nn.TransformerEncoderLayer(d_model=input_dim, nhead=num_heads) for _ in range(num_layers) ]) self.gradient_dict = {} def forward(self, x): for i, layer in enumerate(self.layers): x = layer(x) # 注册梯度钩子 x.register_hook(lambda grad, i=i: self._store_gradient(grad, f"layer_{i}")) return x def _store_gradient(self, grad, name): self.gradient_dict[name] = grad.detach().cpu().numpy() # 自定义数据集类 class MedicalDataset(Dataset): def __init__(self, data_size=1000, input_dim=64): self.data = torch.randn(data_size, input_dim) self.labels = torch.randint(0, 2, (data_size,)) def __len__(self): return len(self.data) def __getitem__(self, idx): return self.data[idx], self.labels[idx] def visualize_gradients(gradient_dict, step): plt.figure(figsize=(10, 6)) for name, grad in gradient_dict.items(): plt.plot(grad.mean(axis=1), label=name) plt.title(f"Gradient Flow at Step {step}") plt.xlabel("Layer Depth") plt.ylabel("Gradient Magnitude") plt.legend() plt.show() # 训练参数 batch_size = 32 epochs = 50 lr = 0.001 # 初始化组件 dataset = MedicalDataset() dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) model = UFormer() criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(model.parameters(), lr=lr) scheduler = CosineAnnealingLR(optimizer, T_max=epochs) # 训练循环 for epoch in range(epochs): model.train() for step, (inputs, labels) in enumerate(dataloader): optimizer.zero_grad() outputs = model(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() # 每100步可视化梯度 if step % 100 == 0: visualize_gradients(model.gradient_dict, epoch*len(dataloader)+step) model.gradient_dict.clear() # 清空梯度缓存 scheduler.step() print(f"Epoch {epoch+1}/{epochs} completed") ``` ### 关键功能说明: 1. **UFormer模型**:采用Transformer编码器堆叠结构,包含梯度追踪功能[^1] 2. **数据加载器**:使用自定义MedicalDataset模拟医疗文本数据加载 3. **学习率调度**:采用余弦退火调度策略优化训练过程[^2] 4. **梯度可视化**:通过注册梯度钩子实现各层梯度流动可视化
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值