Transformer in Practice - Tutorial Series 10: SwinTransformer Source Code Walkthrough 3 (the SwinTransformerBlock Class)

This article takes a close look at the SwinTransformerBlock in SwinTransformer, covering its constructor, its forward pass, and how it works with key components such as WindowAttention and Mlp. The focus is on how local attention is implemented through window partitioning and an optional window shift.

🚩🚩🚩Transformer in Practice - Series Table of Contents

Feel free to leave any questions in the comments below.
All code in this article was run in PyCharm.
The companion code resources for this article have been uploaded.
Click here to download the source code

SwinTransformer Algorithm Principles
SwinTransformer Source Code Walkthrough 1 (Project Setup / SwinTransformer Class)
SwinTransformer Source Code Walkthrough 2 (PatchEmbed Class / BasicLayer Class)
SwinTransformer Source Code Walkthrough 3 (SwinTransformerBlock Class)
SwinTransformer Source Code Walkthrough 4 (WindowAttention Class)
SwinTransformer Source Code Walkthrough 5 (Mlp Class / PatchMerging Class)

5. The SwinTransformerBlock Class

class SwinTransformerBlock(nn.Module):
    def extra_repr(self) -> str:
        # extra_repr customizes the extra information shown when the module is printed
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

5.1 Constructor

SwinTransformerBlock is the basic building block of the Swin Transformer model. It combines self-attention with a multilayer perceptron (MLP), and implements local attention through window partitioning and an optional window shift.

def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if the feature map is no larger than one window, don't partition or shift
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        if self.shift_size > 0:
            # build the attention mask for shifted-window attention (SW-MSA)
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt  # give each region its own id
                    cnt += 1
            mask_windows = window_partition(img_mask, self.window_size)
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            # pairs of positions from different regions get -100, same region gets 0
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None
        self.register_buffer("attn_mask", attn_mask)
  1. dim: number of channels of the input features
  2. input_resolution: resolution (height and width) of the input features
  3. num_heads: number of self-attention heads
  4. window_size: window size, which determines the local extent of the attention
  5. shift_size: size of the window shift, used to implement shifted-window multi-head self-attention (SW-MSA)
  6. mlp_ratio: ratio of the MLP hidden dimension to the number of input channels
  7. qkv_bias: whether the QKV projections use a bias
  8. qk_scale: scaling factor applied to QK
  9. drop: dropout rate
  10. attn_drop: dropout rate applied to the attention weights
  11. drop_path: stochastic depth rate
  12. act_layer / norm_layer: activation and normalization layers, GELU and LayerNorm by default
  13. WindowAttention: the window attention module
  14. Mlp: a module made of fully connected layers, an activation function, and Dropout
  15. img_mask: image mask used to build the mask for shifted-window self-attention
  16. h_slices / w_slices: slices along the height and width, used to partition the image mask into regions
  17. cnt: counter that labels the different regions
  18. mask_windows: the image mask partitioned into windows, with each window's mask flattened into a 1-D vector
  19. window_partition: the window partition function (see section 5.3)
  20. attn_mask: attention mask that prevents positions from different regions attending to each other within a window (see the standalone sketch after this list)
  21. register_buffer: registers the attention mask as a buffer of the model
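
To make the mask construction concrete, here is a small standalone sketch (not from the tutorial's repo) that reproduces the region labeling and mask computation above for a hypothetical 8×8 feature map with window_size=4 and shift_size=2; the -100/0 fill values follow the constructor:

import torch

# hypothetical small case: 8x8 feature map, 4x4 windows, shift of 2
H, W, window_size, shift_size = 8, 8, 4, 2

img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
cnt = 0
for h in slices:
    for w in slices:
        img_mask[:, h, w, :] = cnt  # label each of the 3x3 regions with its own id
        cnt += 1

# partition the mask into 4x4 windows, same logic as window_partition() in section 5.3
mask_windows = img_mask.view(1, H // window_size, window_size, W // window_size, window_size, 1)
mask_windows = mask_windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size * window_size)

# positions from different regions get -100 (suppressed after softmax), same region gets 0
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)

print(img_mask[0, :, :, 0])  # region ids over the 8x8 map
print(attn_mask.shape)       # torch.Size([4, 16, 16]): one mask per window

Printing img_mask shows why the mask is needed: after the cyclic shift, a single window can contain pixels from regions that were not adjacent in the original image, and the mask stops them from attending to each other.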

5.2 Forward Pass

def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)
        # cyclic shift (only for SW-MSA blocks, where shift_size > 0)
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x
        # partition windows and run window attention
        x_windows = window_partition(shifted_x, self.window_size)
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
        attn_windows = self.attn(x_windows, mask=self.attn_mask)
        # merge windows back into a full feature map
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)
        # reverse the cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)
        # residual connections with stochastic depth, then the MLP
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
  1. Original input: torch.Size([4, 3136, 96]), a sequence of length 3136 where each vector has 96 dimensions. As the block is called again in later stages, the shape changes to torch.Size([4, 784, 192]), torch.Size([4, 196, 384]), and torch.Size([4, 49, 768])
  2. H, W = [56, 56], the height and width of the input resolution
  3. B, L, C = [4, 3136, 96], the current input dimensions: batch size, sequence length, and vector dimension
  4. shortcut is kept for the residual connection
  5. norm1(x): torch.Size([4, 3136, 96]), a layer normalization; the shape is unchanged
  6. x.view(B, H, W, C): torch.Size([4, 56, 56, 96]), reshapes the sequence into (Batch_size, Height, Width, Channel)
  7. shift_size determines whether a shift is applied; in the very first block shift_size is 0. window_partition performs the window partitioning: with a [56, 56, 96] feature map and the default 7×7 window, 8×8 = 64 windows can be carved out
  8. shifted_x: torch.Size([4, 56, 56, 96]), x after the cyclic shift; the shift needs nothing more than the built-in torch.roll() (see the small demo after this list)
  9. x_windows: torch.Size([256, 7, 7, 96]), the window partition function produces 256 windows, each with 7×7 positions of 96-dimensional vectors
  10. x_windows: torch.Size([256, 49, 96]), each window is flattened into a 1-D sequence so self-attention can be computed on it
  11. attn_windows: torch.Size([256, 7, 7, 96]), WindowAttention applies window attention to each window, taking the attention mask into account if present; the output is then reshaped back into 7×7 windows
  12. shifted_x: torch.Size([4, 56, 56, 96]), the attended windows are merged back into the full feature map
  13. torch.Size([4, 56, 56, 96]), if a cyclic shift was applied, the reverse shift restores the original positions in the feature map
  14. torch.Size([4, 3136, 96]), the feature map is reshaped back to the original [B, L, C] form
  15. torch.Size([4, 3136, 96]), the residual connection is applied, through stochastic depth if it is configured
  16. torch.Size([4, 3136, 96]), the second normalization layer, then the MLP, then stochastic depth again, completing the second residual connection
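
The cyclic shift in steps 8 and 13 is just torch.roll and its inverse. A minimal sketch on a toy 4×4, single-channel map (sizes chosen here purely for illustration):

import torch

# toy feature map: batch 1, 4x4 spatial grid, 1 channel
x = torch.arange(16.).view(1, 4, 4, 1)

# shift up and left by shift_size=2; rows and columns wrap around cyclically
shifted = torch.roll(x, shifts=(-2, -2), dims=(1, 2))

# the reverse shift restores the original map exactly
restored = torch.roll(shifted, shifts=(2, 2), dims=(1, 2))

print(x[0, :, :, 0])
print(shifted[0, :, :, 0])
print(torch.equal(x, restored))  # True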

5.3 The Window Partition Function window_partition()

def window_partition(x, window_size):
    B, H, W, C = x.shape
    # split H and W into (num_windows_h, window_size) and (num_windows_w, window_size)
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    # gather the window dims together: (num_windows*B, window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows

The shapes below are from the first stage; later stages downsample the feature map, so the numbers change accordingly.

  1. B, H, W, C = torch.Size([4, 56, 56, 96]), the original input of the function; 56*56 equals the earlier sequence length 3136
  2. x = torch.Size([4, 8, 7, 8, 7, 96])
  3. windows = torch.Size([256, 7, 7, 96]). The 256 comes from: the original map is 56×56 and is split into 7×7 windows, so the number of windows is (56/7)×(56/7) = 8×8 = 64; multiplying by the batch size 4 gives 256. The 96-dimensional vectors are unchanged

Window Attention is then computed within each of these 7×7 windows. A round-trip sketch of window_partition and its inverse follows.
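
As a sanity check, here is a minimal sketch (toy sizes, not from the tutorial's repo) that partitions a feature map into windows and merges it back with window_reverse, the inverse function used in forward above; the window_reverse body shown here mirrors the official Swin implementation:

import torch

def window_partition(x, window_size):
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)

def window_reverse(windows, window_size, H, W):
    # inverse of window_partition: (num_windows*B, ws, ws, C) -> (B, H, W, C)
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)

# toy sizes instead of the tutorial's [4, 56, 56, 96]
x = torch.randn(2, 14, 14, 8)
windows = window_partition(x, 7)
print(windows.shape)                                       # torch.Size([8, 7, 7, 8])
print(torch.equal(window_reverse(windows, 7, 14, 14), x))  # True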

