图像修复-CVPR2021-恶劣天气图像修复-TransWeather: Transformer-based Restoration of Images Degraded...

图像修复-恶劣天气图像修复-CVPR2021-TransWeather: Transformer-based Restoration of Images Degraded by Adverse Weather Conditions

TransWeather 提出了一种基于 Transformer 的端到端模型,采用 Intra-Patch Transformer 结构增强局部注意力,并引入可学习的天气类型嵌入,仅用单个编码器和解码器即可高效去除多种恶劣天气条件,在多个数据集上显著超越现有方法。

论文链接:TransWeather: Transformer-based Restoration of Images Degraded by Adverse Weather Conditions

代码链接:/jeya-maria-jose/TransWeather

一、主要创新点
  1. 统一的 Transformer 结构:提出了一种基于 Transformer 的端到端模型 TransWeather,仅使用单个编码器和解码器即可去除所有类型的恶劣天气条件(如雨、雾、雪),相比于 All-in-One 方法使用多个编码器的设计更加高效。

  2. 引入 Intra-Patch Transformer Blocks:在 Transformer 编码器中采用 Intra-Patch Transformer 结构,增强图像块(patch)内部的注意力机制,从而更有效地去除较小尺度的天气退化影响,提高图像修复能力。

  3. 可学习的天气类型嵌入(Weather Type Embeddings):在 Transformer 解码器中引入了可学习的天气类型嵌入,使模型能够自适应不同的天气退化情况,从而提升去天气效果(主要是借鉴思想是Detr的cross attention 的块查询机制)。

  4. 性能突破:TransWeather 在多个测试数据集上显著超越了 All-in-One 方法以及针对特定天气类型优化的去天气方法,例如:

    • 在 Test1(雨+雾)数据集上提升 +6.34 PSNR

    • 在 SnowTest100K-L 数据集上提升 +4.93 PSNR

    • 在 RainDrop 测试集上提升 +3.11 PSNR

二、模型架构图

从图中可以看出主要由三部分组成:

  1. Transformer Encoder(编码器):Transformer Encoder 采用层次化结构,结合多头自注意力和前馈神经网络(FFN)进行特征提取,并通过重叠补丁合并(Overlapped Patch Merging)保持特征一致性。为了提高计算效率,模型使用缩减比 R 降低自注意力计算复杂度,并在 FFN 内引入深度卷积(DW-C)以增强局部特征学习。此外,设计了 Intra-Patch Transformer Block(Intra-PT),在各阶段内部对特征划分子补丁进行额外的自注意力计算,专门提取小尺度特征,以更好地去除雨滴、轻雾等细微降质。最终,Intra-PT 结果与主 Transformer Block 特征融合,实现高效的全局-局部特征提取,从而在各种天气降质场景下取得更优的图像恢复效果。
  2. Transformer Decoder(解码器):Transformer Decoder 受 Transformer 和 DETR 结构启发,采用可学习的天气类型查询(weather type queries)来解码任务,并预测任务特征向量,以恢复受天气影响的图像。与传统自注意力机制不同,该 Decoder 采用查询(Q)作为天气类型嵌入,而键(K)和值(V)来自 Transformer Encoder 的最后阶段特征Decoder 通过单阶段多块结构提取和融合特征,并结合 Transformer Encoder 的各层信息,最终通过卷积尾部(Convolutional Tail)重建清晰图像。这种设计使模型能够适应多种天气降质情况,并提升恢复质量。
  3. Convolutional Projection Block(卷积映射恢复):模型将 Transformer Encoder 的层次化特征与 Transformer Decoder 任务特征作为输入,并通过 四层卷积 逐步重建清晰图像。每层卷积前加入 上采样层 以恢复原始分辨率,同时在卷积尾部引入 跳跃连接,与 Transformer Encoder 进行特征融合,以提高恢复能力。最后,采用 tanh 激活函数约束输出,确保生成的去天气降质图像具有稳定的像素分布。

下面将从源码层面解释Transformer Encoder(编码器)和Transformer Decoder(解码器),对于Convolutional Projection Block(卷积映射恢复)只是简单的卷积还原操作。

在这里插入图片描述

三、Transformer Encoder
  1. 层次化特征提取:在 Transformer Encoder 内部,通过多级结构提取不同层次的特征,使模型能够捕捉低级和高级信息,提高对各种天气降质的恢复能力。
  2. 重叠补丁合并(Overlapped Patch Merging):在每个阶段,通过合并重叠补丁(patches)来保持特征尺寸一致,同时提高特征表达能力,减少信息损失。
  3. Transformer Block 计算:使用多头自注意力(MSA)计算输入特征的自注意力信息,并通过前馈神经网络(FFN)进一步处理,使模型能够学习到全局和局部信息。
  4. 计算复杂度优化:为了降低计算量,引入缩减比 R,将自注意力计算复杂度从 O(N2)降至 O(N2R),优化计算效率。
  5. 改进 FFN 结构:在 FFN 内部引入深度卷积(Depth-Wise Convolution, DW-C),增强 Transformer 位置感知能力,同时保留局部信息,提高去天气降质效果。
  6. Intra-Patch Transformer Block(Intra-PT):在主 Transformer Block 之间额外加入 Intra-PT 模块,该模块对输入特征划分子补丁(sub-patches),专门用于提取小尺度特征,有助于消除微小降质影响,如小雨滴或轻雾。
  7. 融合 Intra-PT 结果:Intra-PT 的自注意力特征与主 Transformer Block 提取的特征在同一阶段进行融合,使最终特征表达更加丰富,提高图像恢复质量。
Transformer Encoder源码
# 来自源码文件:https://github.com/jeya-maria-jose/TransWeather/blob/main/transweather_model.py
class EncoderTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
                 num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
                 attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
                 depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]):
        super().__init__()
        self.num_classes = num_classes
        self.depths = depths

        # 定义不同阶段的 Patch Embedding
        self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans,
                                              embed_dim=embed_dims[0])
        self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0],
                                              embed_dim=embed_dims[1])
        self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1],
                                              embed_dim=embed_dims[2])
        self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2],
                                              embed_dim=embed_dims[3])

        # 定义 Intra-patch transformer blocks 的 Patch Embedding
        self.mini_patch_embed1 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0],
                                                   embed_dim=embed_dims[1])
        self.mini_patch_embed2 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1],
                                                   embed_dim=embed_dims[2])
        self.mini_patch_embed3 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2],
                                                   embed_dim=embed_dims[3])
        self.mini_patch_embed4 = OverlapPatchEmbed(img_size=img_size // 32, patch_size=3, stride=2, in_chans=embed_dims[0],
                                                   embed_dim=embed_dims[3])

        # 定义主编码器部分的 Transformer Blocks
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # 计算 Stochastic Depth decay
        cur = 0
        self.block1 = nn.ModuleList([
            Block(dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias,
                  qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i],
                  norm_layer=norm_layer, sr_ratio=sr_ratios[0])
            for i in range(depths[0])])
        self.norm1 = norm_layer(embed_dims[0])

        # 定义 Intra-patch Transformer Block
        self.patch_block1 = nn.ModuleList([
            Block(dim=embed_dims[1], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias,
                  qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur],
                  norm_layer=norm_layer, sr_ratio=sr_ratios[0])
        ])
        self.pnorm1 = norm_layer(embed_dims[1])

        # 继续构建后续层的 Transformer Blocks
        cur += depths[0]
        self.block2 = nn.ModuleList([...])  # 类似 block1 的结构
        self.norm2 = norm_layer(embed_dims[1])
        self.patch_block2 = nn.ModuleList([...])
        self.pnorm2 = norm_layer(embed_dims[2])
        cur += depths[1]
        self.block3 = nn.ModuleList([...])
        self.norm3 = norm_layer(embed_dims[2])
        self.patch_block3 = nn.ModuleList([...])
        self.pnorm3 = norm_layer(embed_dims[3])
        cur += depths[2]
        self.block4 = nn.ModuleList([...])
        self.norm4 = norm_layer(embed_dims[3])

        self.apply(self._init_weights)  # 初始化权重

    def _init_weights(self, m):
        """初始化网络权重"""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward_features(self, x):
        """前向传播特征提取"""
        B = x.shape[0]
        outs = []
        # Stage 1
        x1, H1, W1 = self.patch_embed1(x)
        x2, H2, W2 = self.mini_patch_embed1(x1.permute(0, 2, 1).reshape(B, 64, H1, W1))
        for blk in self.block1:
            x1 = blk(x1, H1, W1)
        x1 = self.norm1(x1).reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x1)

        # 类似的处理 Stage 2, 3, 4
        return outs

    def forward(self, x):
        """前向传播"""
        return self.forward_features(x)

四、Transformer Decoder
  1. 借鉴自 Transformer 和 DETR:原始 Transformer Decoder 采用自回归(autoregressive)方式逐步预测序列,而 DETR 使用对象查询(object queries)来解码目标框坐标和类别标签。受此启发,本文引入天气类型查询(weather type queries),用于解码任务并预测任务特征向量,以恢复干净图像。
  2. 天气类型查询(Weather Type Queries):这些查询是可学习的嵌入(learnable embeddings),在模型训练过程中与其他参数一起学习,使模型能够适应不同的天气降质类型。
  3. 注意力机制:Transformer Decoder 关注来自 Transformer Encoder 的特征输出,通过查询机制提取任务相关的信息。
  4. 单阶段多块结构:与典型的 Encoder-Decoder 结构类似,该 Transformer Decoder 仅在单一阶段运行,但包含多个 Transformer Block,每个块进行特征提取和融合。
  5. 查询-键值(QKV)设计:不同于传统的自注意力(self-attention)机制,该 Transformer Decoder 采用 Q(查询)为天气类型查询的可学习嵌入,而 K 和 V(键和值)分别来自 Transformer Encoder 的最终阶段特征。
  6. 任务特征向量:Decoder 输出的解码特征表示任务特征向量(task feature vector),该特征向量与 Transformer Encoder 各阶段的特征融合,以增强恢复能力。

在这里插入图片描述

Transformer Decoder源码
# 来自源码文件:https://github.com/jeya-maria-jose/TransWeather/blob/main/transweather_model.py
class DecoderTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
                 num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
                 attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
                 depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]):
        """
        Transformer 解码器

        参数:
        - img_size: 输入图像尺寸
        - patch_size: Patch 的大小
        - in_chans: 输入通道数
        - num_classes: 分类类别数
        - embed_dims: 每个阶段的嵌入维度
        - num_heads: 每个阶段的注意力头数
        - mlp_ratios: MLP 扩展比例
        - qkv_bias: 是否使用 QKV 偏置
        - qk_scale: QK 缩放因子
        - drop_rate: Dropout 比例
        - attn_drop_rate: 注意力 Dropout 比例
        - drop_path_rate: Drop Path 比例
        - norm_layer: 归一化层类型
        - depths: 每个阶段的 Transformer Block 数量
        - sr_ratios: 空间降采样率
        """
        super().__init__()
        self.num_classes = num_classes
        self.depths = depths

        # Patch 嵌入层 (用于处理 Encoder 的输出)
        self.patch_embed1 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, 
                                              in_chans=embed_dims[3], embed_dim=embed_dims[3])

        # 计算 Drop Path 比例
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  
        cur = 0

        # Transformer 解码器第一阶段
        self.block1 = nn.ModuleList([
            Block_dec(dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, 
                      qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], 
                      norm_layer=norm_layer, sr_ratio=sr_ratios[3])
            for i in range(depths[0])
        ])
        self.norm1 = norm_layer(embed_dims[3])

        cur += depths[0]

        # 初始化权重
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """
        权重初始化方法
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def init_weights(self, pretrained=None):
        """
        加载预训练权重
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger)

    def reset_drop_path(self, drop_path_rate):
        """
        重新设置 Drop Path 率
        """
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
        cur = 0
        for i in range(self.depths[0]):
            self.block1[i].drop_path.drop_prob = dpr[cur + i]

        cur += self.depths[0]
        for i in range(self.depths[1]):
            self.block2[i].drop_path.drop_prob = dpr[cur + i]

        cur += self.depths[1]
        for i in range(self.depths[2]):
            self.block3[i].drop_path.drop_prob = dpr[cur + i]

        cur += self.depths[2]
        for i in range(self.depths[3]):
            self.block4[i].drop_path.drop_prob = dpr[cur + i]

    def forward_features(self, x):
        """
        提取解码器的特征
        """
        x = x[3]  # 选取 Encoder 的最后一层特征
        B = x.shape[0]  # 获取 batch size
        outs = []

        # 第一阶段解码
        x, H, W = self.patch_embed1(x)
        for blk in self.block1:
            x = blk(x, H, W)
        x = self.norm1(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        return outs

    def forward(self, x):
        """
        前向传播
        """
        x = self.forward_features(x)
        # x = self.head(x)  # 这里可以加上分类头部
        return x

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值