Swin-Conv-UNet

本文希望从网络架构设计和新颖的训练数据合成来提升去噪性能。作者提出了 Swin-Conv-UNet 盲去噪模型。整体结构来自 UNet,模块的设计思想结合了 DRUNet 和 SwinIR。

基于 Swin-Conv-UNet 结果和数据分析的盲去噪方法

论文名称:Practical Blind Denoising via Swin-Conv-UNet and Data Synthesis 

论文地址:https://arxiv.org/pdf/2203.13278.pdf

盲去噪任务介绍

作为基本的图像复原问题,图像去噪 (Image Denoising) 问题的目标是从噪声图片中重建清晰的图像,越来越受到人们的关注。它是诸多视觉任务的第一步,且能够帮助评估不同图像先验和优化算法的有效性。目前已经提出了基于深度学习的方法的多种网络架构和超分网络的训练策略来改善去噪的性能。顾名思义,Image Denoising 任务需要两张图片,一张清晰的图和一张带噪声的图。去噪模型的目的是根据后者生成前者,而退化模型的目的是根据前者生成后者。经典去噪任务认为:噪声类型和噪声水平 (noise type and noise level) 是已知的。 但是,在实际应用中,这种噪声作用十分复杂,不但强度水平未知,而且噪声类型也难以简单建模 (白噪声,JPEG 压缩噪声,泊松噪声,相机传感器噪声等)。这种噪声类型或噪声水平未知的超分任务我们称之为盲去噪任务 (Blind Image Denoising)。对于这类任务的深度学习方法目前有两大类:其一是把噪声类型简单建模为加性高斯白噪声 (Additive White Gaussian Noise, AWGN),并想办法改进模型,因为网络架构可以帮助捕获图像先验来提升去噪任务的性能;其二是去关注训练数据或者噪声建模。但是,以 AWGN 简单建模的训练样本和真实图像之间存在一个域差。以 AWGN 为噪声训练得到的网络在实际应用时,这种域差距将导致比较糟糕的性能。

从模型架构和噪声类型建模两个方面提升去噪性能

已有一些相关的工作希望建模更接近真实世界噪声,通过利用数字传感器的物理特性和成像流水线的步骤,已有的工作[1]设计了一种相机传感器噪声合成方法,并提供了一种有效的深度原始图像去噪模型。但是它主要关注相机传感器引起的噪声,而缺乏对通用盲去噪方法的思考。

去噪模型的架构可以帮助捕获图像先验来提升去噪任务的性能,所以也是值得考虑的因素之一。

所以这个工作试图通过新颖的网络架构设计和新颖的训练数据合成来提升去噪性能。对于网络架构设计的部分,不同类型的架构具有互补的图像先验的捕捉能力,可以结合使用以提高性能。因此,作者考虑了两种模型,分别是 DRUNet 和 SwinIR。作者提出了一个 Swin-Conv 模块,以结合残差卷积层的 local 建模能力和 Swin 模块的 non-local 建模能力,然后将其作为主要构建模块插入到 UNet 模型架构中。对于噪声模拟部分,作者同时考虑了加性高斯白噪声,泊松噪声,斑点噪声,JPEG 压缩噪声和相机传感器噪声等类型,以模拟真实世界噪声。

图像复原问题的建模

可以看出解决盲去噪的关键在于对 noise 图像的退化过程进行建模以及对 clean 图像的先验设计。

很明显,退化过程是由来自训练数据的噪声图像隐式定义的,这表明训练数据的噪声图像负责深度盲去噪模型来捕捉退化过程的知识。为了提高深度盲去噪模型的图像先验建模能力,应该重点改善以下三个因素,包括网络结构、模型大小和干净的图像训练数据。很明显,退化过程是由训练数据的噪声图像隐式定义的这表明训练数据的噪声图像负责 Blind Imag

### 使用Swin-UNet进行3D数据集训练 为了使用Swin-UNet模型训练自定义的3D数据集,需考虑几个关键方面来调整和优化该架构以适应三维输入。以下是详细的说明: #### 1. 数据预处理与加载 对于3D数据集而言,数据预处理至关重要。考虑到医疗影像或其他类型的体积数据通常具有较大的尺寸,在内存有限的情况下可能难以一次性加载整个体素立方体。因此,建议采用分片策略或滑动窗口技术逐块读取并处理数据。 ```python import numpy as np from torch.utils.data import Dataset, DataLoader class Custom3DDataset(Dataset): def __init__(self, data_paths, label_paths=None, transform=None): self.data_paths = data_paths self.label_paths = label_paths self.transform = transform def __len__(self): return len(self.data_paths) def __getitem__(self, idx): volume_data = np.load(self.data_paths[idx]) # 假设为numpy数组形式存储 if self.label_paths is not None: mask_data = np.load(self.label_paths[idx]) sample = {'image': volume_data, 'label': mask_data} if self.transform: sample = self.transform(sample) return sample['image'], sample['label'] ``` 此代码片段展示了一个简单的`Custom3DDataset`类实现,适用于加载3D医学图像及其对应标注[^1]。 #### 2. 修改网络结构支持3D卷积操作 原始版本中的二维卷积层需要替换为相应的三维版本(`Conv3d`, `BatchNorm3d`)以便能够接收并处理形状类似于`(batch_size, channels, depth, height, width)`的数据张量作为输入。 ```python import torch.nn as nn def conv_block(in_channels, out_channels): return nn.Sequential( nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1), nn.BatchNorm3d(out_channels), nn.ReLU(inplace=True), nn.MaxPool3d(kernel_size=(1, 2, 2)) ) ``` 上述函数定义了一个基础的3D卷积模块,可以根据实际需求进一步扩展成完整的编码器解码器框架[^2]。 #### 3. 自定义损失函数设计 针对特定应用场景(如二分类),可以构建专门定制化的损失计算方式。例如Dice Loss常用于衡量预测结果与真实标签之间的相似度,尤其适合解决类别不平衡问题。 ```python import torch def dice_loss(preds, targets, smooth=1e-6): preds = preds.view(-1) targets = targets.view(-1) intersection = (preds * targets).sum() union = preds.sum() + targets.sum() score = (2.*intersection + smooth)/(union + smooth) loss = 1 - score return loss ``` 这段Python脚本实现了Dice系数的变体——Dice Loss,它有助于提高分割任务的效果尤其是当目标区域较小的时候。 #### 4. 训练过程配置 最后一步涉及设置好所有超参数以及调用合适的优化算法来进行迭代更新权重直至收敛。这里推荐使用AdamW优化器配合余弦退火调度机制来加速收敛速度同时防止过拟合现象发生。 ```python model = SwinUNET_3D().cuda() optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs) for epoch in range(num_epochs): model.train() running_loss = 0. for images, labels in train_loader: optimizer.zero_grad() outputs = model(images.cuda()) loss = criterion(outputs, labels.cuda()) + lambda_dice*dice_loss(outputs, labels.cuda()) loss.backward() optimizer.step() scheduler.step() running_loss += loss.item() ``` 以上代码段展示了典型的PyTorch风格的训练循环逻辑,其中包含了前向传播、反向传播及参数更新的核心步骤[^4]。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值