Swin Transformer之PatchMerging原理及源码

1.图示

 2.原理

Patch Merging层进行下采样。该模块的作用是做降采样,用于缩小分辨率,调整通道数 进而形成层次化的设计,同时也能节省一定运算量。

在CNN中,则是在每个Stage开始前用stride=2的卷积/池化层来降低分辨率。

patch Merging是一个类似于池化的操作,但是比Pooling操作复杂一些。池化会损失信息,patch Merging不会。

每次降采样是两倍,因此在行方向和列方向上,按位置间隔2选取元素,拼成新的patch,再把所有patch都concat起来作为一整个张量,最后展开。此时通道维度会变成原先的4倍(因为H,W各缩小2倍),此时再通过一个全连接层再调整通道维度为原来的两倍。

3.源码

import torch
import torch.nn as nn
import math
import numpy as np


class PatchMerging(nn.Module):
    r""" Patch Merging Layer.
    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()

        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        B, L, C = x.shape
        H = int(math.sqrt(L))
        W = int(math.sqrt(L))

        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)
        print('--------------------------')
        print(x)
        print('原始图像4D维度:',x.shape)

        # 在行和列方向上间隔1选取元素
        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        print('--------------------------')
        print(x0)
        print('切分图像4D维度:',x0.shape)
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        print('--------------------------')
        print(x1)
        print('切分图像4D维度:',x1.shape)
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        print('--------------------------')
        print(x2)
        print('切分图像4D维度:',x2.shape)
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        print('--------------------------')
        print(x3)
        print('切分图像4D维度:',x3.shape)

        # 拼接到一起作为一整个张量
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        print('--------------------------')
        print(x)
        print('拼接整个张量后:',x.shape)
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C
        print('--------------------------')
        print(x)
        print('合并行和列后:',x.shape)

        x = self.norm(x)           # 归一化操作
        print('--------------------------')
        print(x)
        print('归一化操作后:', x.shape)
        x = self.reduction(x)      # 降维,通道降低2倍
        print('--------------------------')
        print(x)
        print('通道降低2倍后:', x.shape)

        return x
if __name__ == "__main__":

    x = np.array([[0, 2, 0, 2],[ 1, 3, 1, 3 ],[ 0, 2, 0, 2 ],[ 1, 3, 1, 3 ]])
    x = torch.from_numpy(x)
    x = x.view(1, 4*4, 1)
    x=x.to(torch.float32)
    model = PatchMerging(1)
    print('--------------------------')
    print(x)
    print('原始图像3D维度:', x.shape)
    y = model(x)


  • 21
    点赞
  • 69
    收藏
    觉得还不错? 一键收藏
  • 6
    评论
Transformer在许多NLP(自然语言处理)任务中取得了最先进的成果。 Swin Transformer是在ViT基础上发展而来,是Transformer应用于CV(计算机视觉)领域又一里程碑式的工作。它可以作为通用的骨干网络,用于图片分类的CV任务,以及下游的CV任务,如目标检测、实例分割、语义分割等,并取得了SOTA的成果。Swin Transformer获得了ICCV 2021的最佳论文奖。本课程对Swin Transformer原理与PyTorch实现代码进行精讲,来帮助大家掌握其详细原理和具体实现;并且使用Swin Transformer对17个类别花朵数据集进行图片分类的项目实战。  Ÿ   原理精讲部分包括:Transformer的架构概述、Transformer的Encoder 、Transformer的Decoder、Swin Transformer的网络架构、Patch Merging、SW-MSA、Relative Position Bias、MSA与W-MSA计算量分析、实验结果及性能。 Ÿ   项目实战部分包括:安装软件环境和PyTorch、安装Swin-Transformer、数据集自动划分、修改配置文件、训练数据集、测试训练出的网络模型。Ÿ   代码精讲部分使用PyCharm对Swin Transformer的PyTorch代码进行逐行解读,包括:PatchEmbed、SwinTransformerBlock、PatchMerging、推理过程和训练过程实现代码解读。 相关课程:Transformer原理与代码精讲(PyTorch)https://edu.csdn.net/course/detail/36697Transformer原理与代码精讲(TensorFlow)https://edu.csdn.net/course/detail/36699ViT(Vision Transformer原理与代码精讲 https://edu.csdn.net/course/detail/36719DETR原理与代码精讲 https://edu.csdn.net/course/detail/36768Swin Transformer实战目标检测:训练自己的数据集 https://edu.csdn.net/course/detail/36585Swin Transformer实战实例分割:训练自己的数据集 https://edu.csdn.net/course/detail/36586 

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值