Swin Transformer Study Notes (1): Overall Architecture

This post takes a detailed look at Swin Transformer, an efficient model that improves on ViT by using window attention to reduce computational complexity, focusing on its structural design, computational optimizations, and applications such as object detection.

Since its debut, Swin Transformer has attracted enormous attention from researchers thanks to its excellent performance. As a general-purpose backbone, it has been widely and successfully applied to object detection, semantic segmentation, denoising, and many other tasks. In this post, we take a close look at how Swin Transformer works.

Swin Transformer is an improvement on ViT. ViT applies global Transformer attention directly, which has a very high computational complexity and therefore a very high computational cost. This is exactly where Swin Transformer's design is clever: its key idea is to restrict attention computation to local windows, which greatly reduces the amount of computation. Compared with PVT, which reduces cost by downsampling K and V, Swin Transformer's design is more involved. A rough comparison of the attention cost of global MSA and window MSA is sketched right below. In this series I use the swin_T_224_1k variant, the most lightweight member of the Swin family. Without further ado, let's begin.
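To make the cost difference concrete, here is a rough back-of-the-envelope sketch (mine, not from the original code) based on the complexity formulas in the Swin paper, where the quadratic (hw)² term of global attention is replaced by a term that is linear in hw:

```python
# Complexity formulas from the Swin Transformer paper:
#   Omega(MSA)   = 4*h*w*C^2 + 2*(h*w)^2*C
#   Omega(W-MSA) = 4*h*w*C^2 + 2*M^2*h*w*C     (M = window size)
def msa_flops(h, w, C):
    return 4 * h * w * C ** 2 + 2 * (h * w) ** 2 * C

def wmsa_flops(h, w, C, M=7):
    return 4 * h * w * C ** 2 + 2 * M ** 2 * h * w * C

h, w, C = 160, 120, 64   # stage-1 feature-map size used later in this post
print(f"global MSA : {msa_flops(h, w, C):.3e} FLOPs")
print(f"window MSA : {wmsa_flops(h, w, C):.3e} FLOPs")
```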

Overall Architecture

Let's start with the overall architecture. As the figure shows, like PVT the network is divided into 4 stages (each stage outputs a feature map of a different size). Every stage except the first starts with a Patch Merging module, whose job is to shrink the feature map: the Transformer blocks themselves do not change the spatial size, so Patch Merging is what provides the multi-scale features. The patches here play the same role as Patch Embedding in PVT or the patches in ViT; only the construction differs. A minimal sketch of Patch Merging is given after the figure.
[Figure: Swin Transformer overall architecture]
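As a reference, here is a minimal sketch of Patch Merging (my own simplified version, assuming even input sizes; the real implementation also pads odd-sized feature maps): each 2×2 group of neighboring tokens is concatenated along the channel dimension (C → 4C) and then projected back to 2C with a linear layer, so the resolution is halved while the channels are doubled.

```python
import torch
import torch.nn as nn

class PatchMerging(nn.Module):
    """Minimal Patch Merging sketch: halve H and W, double C."""
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(4 * dim)
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)

    def forward(self, x, H, W):
        B, L, C = x.shape                 # x: (B, H*W, C)
        x = x.view(B, H, W, C)
        x0 = x[:, 0::2, 0::2, :]          # top-left token of each 2x2 group
        x1 = x[:, 1::2, 0::2, :]          # bottom-left
        x2 = x[:, 0::2, 1::2, :]          # top-right
        x3 = x[:, 1::2, 1::2, :]          # bottom-right
        x = torch.cat([x0, x1, x2, x3], dim=-1)        # (B, H/2, W/2, 4C)
        x = x.view(B, -1, 4 * C)                       # (B, H/2*W/2, 4C)
        return self.reduction(self.norm(x))            # (B, H/2*W/2, 2C)

# e.g. the stage-1 output (B, 160*120, 64) becomes (B, 80*60, 128)
print(PatchMerging(64)(torch.randn(2, 160 * 120, 64), 160, 120).shape)
```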

Besides Patch Merging, the other component is the Swin Transformer Block, which is the main event. It consists of LayerNorm, Window Attention (W-MSA), Shifted Window Attention (SW-MSA), and MLP modules. To make the overall architecture easier to follow, we will first walk through the concrete transformations from the outside in; before that, a quick sketch of the window partitioning used by W-MSA is given below.
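The sketch (mine, following the reference implementation) shows what "window attention" means in practice: the feature map is first cut into non-overlapping windows of window_size × window_size tokens, and self-attention is then computed independently inside each window.

```python
import torch

def window_partition(x, window_size=7):
    """Split a (B, H, W, C) feature map into non-overlapping windows of
    window_size x window_size tokens (H and W assumed divisible by
    window_size; the real implementation pads first)."""
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    return x.view(-1, window_size, window_size, C)   # (num_windows*B, ws, ws, C)

# a 56x56 map with C=64 yields 8*8 = 64 windows per image
print(window_partition(torch.randn(2, 56, 56, 64)).shape)  # torch.Size([128, 7, 7, 64])
```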

Overall external transformation process of Swin Transformer

def forward_raw(self, x):
        """Forward function."""
        x = self.patch_embed(x)                 # (B, C, H/4, W/4)
        Wh, Ww = x.size(2), x.size(3)
        if self.ape:
            # interpolate the position embedding to the corresponding size
            absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
            x = (x + absolute_pos_embed).flatten(2).transpose(1, 2)  # B Wh*Ww C
        else:
            x = x.flatten(2).transpose(1, 2)    # (B, Wh*Ww, C)
        x = self.pos_drop(x)
        outs = []
        for i in range(self.num_layers):        # iterate over the 4 stages
            layer = self.layers[i]
            x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                x_out = norm_layer(x_out)
                # (B, H*W, C) -> (B, C, H, W) for downstream dense-prediction heads
                out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
                outs.append(out)
        return tuple(outs)

Input: x torch.Size([2, 3, 640, 480])
After Patch Embed it becomes torch.Size([2, 64, 160, 120]); the 64 here is the embedding dimension we set ourselves, and the height and width are each reduced to a quarter of the original.

x = self.patch_embed(x)
Wh, Ww = x.size(2), x.size(3) records the feature-map size at this point: 160, 120
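For reference, here is a minimal sketch of what patch_embed does (my own simplification; the actual implementation also pads inputs whose size is not a multiple of the patch size and may apply a LayerNorm): a 4×4 convolution with stride 4 splits the image into non-overlapping 4×4 patches and projects each one to embed_dim channels.

```python
import torch
import torch.nn as nn

class PatchEmbed(nn.Module):
    """Minimal sketch: non-overlapping 4x4 patches -> embed_dim channels."""
    def __init__(self, in_chans=3, embed_dim=64, patch_size=4):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):          # x: (B, 3, H, W), H and W divisible by patch_size
        return self.proj(x)        # (B, embed_dim, H/4, W/4)

print(PatchEmbed()(torch.randn(2, 3, 640, 480)).shape)  # torch.Size([2, 64, 160, 120])
```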

Next, the code checks whether to add an absolute position embedding, controlled by the ape flag, which defaults to False.
Then x is flattened and its dimensions are transposed: x = x.flatten(2).transpose(1, 2), giving torch.Size([2, 19200, 64]) (19200 = 160 × 120).
After that, the tensor goes through the feature-extraction stages; there are 4 in total.

for i in range(self.num_layers):
            layer = self.layers[i]
            x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                x_out = norm_layer(x_out)
                out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
                outs.append(out)

The core line is: x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
Stage 1: x_out: torch.Size([2, 19200, 64]), out: torch.Size([2, 64, 160, 120])

Stage 2: x_out: torch.Size([2, 4800, 128]), out: torch.Size([2, 128, 80, 60])

Stage 3: x_out: torch.Size([2, 1200, 256]), out: torch.Size([2, 256, 40, 30])

Stage 4: x_out: torch.Size([2, 1200, 256]), out: torch.Size([2, 256, 40, 30]), the same as stage 3 (in the structure printed below, the third BasicLayer has no Patch Merging, so the resolution and channel count do not change).


Note that these output feature maps do not exactly match the overall architecture figure; we take the code as the ground truth.

The detailed structure of the four feature-extraction stages is printed below (be warned, it is long).
When checking the printed model below, however, you will notice that there seems to be no Shifted Window Attention (SW-MSA) module, nor any matching definition in the code. This is because SW-MSA can in fact be implemented through Window Attention (W-MSA): it only requires passing a shift_size parameter, and the value of shift_size is chosen relative to window_size, as illustrated in the figure below; a small sketch of the cyclic shift follows.
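A minimal sketch (mine, following the reference Swin implementation) of how SW-MSA reuses the same WindowAttention module: the feature map is cyclically shifted with torch.roll before window partitioning, attention is computed inside the windows (with an attention mask for the wrapped-around regions), and the result is shifted back. shift_size = 0 recovers plain W-MSA; otherwise shift_size is typically window_size // 2.

```python
import torch

window_size = 7
shift_size = window_size // 2          # 3 for a 7x7 window

x = torch.randn(2, 56, 56, 64)         # (B, H, W, C), H and W divisible by window_size
if shift_size > 0:                     # SW-MSA
    shifted_x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2))
else:                                  # plain W-MSA
    shifted_x = x
# ... then window_partition(shifted_x), masked attention, window_reverse,
# and finally roll back:
# x = torch.roll(shifted_x, shifts=(shift_size, shift_size), dims=(1, 2))
```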

[Figure: the relationship between shift_size and window_size]
Indeed, in the printed model structure below, the two kinds of blocks look identical.

ModuleList(
  (0): BasicLayer(
    (blocks): ModuleList(
      (0): SwinTransformerBlock(
        (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (attn): WindowAttention(
          (qkv): Linear(in_features=64, out_features=192, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=64, out_features=64, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
          (softmax): Softmax(dim=-1)
        )
        (drop_path): Identity()
        (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=64, out_features=256, bias=True)
          (act): GELU()
          (fc2): Linear(in_features=256, out_features=64, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
      (1): SwinTransformerBlock(
        (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (attn): WindowAttention(
          (qkv): Linear(in_features=64, out_features=192, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=64, out_features=64, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
          (softmax): Softmax(dim=-1)
        )
        (drop_path): DropPath(drop_prob=0.018)
        (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=64, out_features=256, bias=True)
          (act): GELU()
          (fc2): Linear(in_features=256, out_features=64, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (downsample): PatchMerging(
      (reduction): Linear(in_features=256, out_features=128, bias=False)
      (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    )
  )
  (1): BasicLayer(
    (blocks): ModuleList(
      (0): SwinTransformerBlock(
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (attn): WindowAttention(
          (qkv): Linear(in_features=128, out_features=384, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=128, out_features=128, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
          (softmax): Softmax(dim=-1)
        )
        (drop_path): DropPath(drop_prob=0.036)
        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=128, out_features=512, bias=True)
          (act): GELU()
          (fc2): Linear(in_features=512, out_features=128, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
      (1): SwinTransformerBlock(
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (attn): WindowAttention(
          (qkv): Linear(in_features=128, out_features=384, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=128, out_features=128, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
          (softmax): Softmax(dim=-1)
        )
        (drop_path): DropPath(drop_prob=0.055)
        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=128, out_features=512, bias=True)
          (act): GELU()
          (fc2): Linear(in_features=512, out_features=128, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (downsample): PatchMerging(
      (reduction): Linear(in_features=512, out_features=256, bias=False)
      (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    )
  )
  (2): BasicLayer(
    (blocks): ModuleList(
      (0): SwinTransformerBlock(
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (attn): WindowAttention(
          (qkv): Linear(in_features=256, out_features=768, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=256, out_features=256, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
          (softmax): Softmax(dim=-1)
        )
        (drop_path): DropPath(drop_prob=0.073)
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=256, out_features=1024, bias=True)
          (act): GELU()
          (fc2): Linear(in_features=1024, out_features=256, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
      (1): SwinTransformerBlock(
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (attn): WindowAttention(
          (qkv): Linear(in_features=256, out_features=768, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=256, out_features=256, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
          (softmax): Softmax(dim=-1)
        )
        (drop_path): DropPath(drop_prob=0.091)
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=256, out_features=1024, bias=True)
          (act): GELU()
          (fc2): Linear(in_features=1024, out_features=256, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
      (2): SwinTransformerBlock(
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (attn): WindowAttention(
          (qkv): Linear(in_features=256, out_features=768, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=256, out_features=256, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
          (softmax): Softmax(dim=-1)
        )
        (drop_path): DropPath(drop_prob=0.109)
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=256, out_features=1024, bias=True)
          (act): GELU()
          (fc2): Linear(in_features=1024, out_features=256, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
      (3): SwinTransformerBlock(
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (attn): WindowAttention(
          (qkv): Linear(in_features=256, out_features=768, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=256, out_features=256, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
          (softmax): Softmax(dim=-1)
        )
        (drop_path): DropPath(drop_prob=0.127)
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=256, out_features=1024, bias=True)
          (act): GELU()
          (fc2): Linear(in_features=1024, out_features=256, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
      (4): SwinTransformerBlock(
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (attn): WindowAttention(
          (qkv): Linear(in_features=256, out_features=768, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=256, out_features=256, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
          (softmax): Softmax(dim=-1)
        )
        (drop_path): DropPath(drop_prob=0.145)
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=256, out_features=1024, bias=True)
          (act): GELU()
          (fc2): Linear(in_features=1024, out_features=256, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
      (5): SwinTransformerBlock(
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (attn): WindowAttention(
          (qkv): Linear(in_features=256, out_features=768, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=256, out_features=256, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
          (softmax): Softmax(dim=-1)
        )
        (drop_path): DropPath(drop_prob=0.164)
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=256, out_features=1024, bias=True)
          (act): GELU()
          (fc2): Linear(in_features=1024, out_features=256, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
    )
  )
  (3): BasicLayer(
    (blocks): ModuleList(
      (0): SwinTransformerBlock(
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (attn): WindowAttention(
          (qkv): Linear(in_features=256, out_features=768, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=256, out_features=256, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
          (softmax): Softmax(dim=-1)
        )
        (drop_path): DropPath(drop_prob=0.182)
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=256, out_features=1024, bias=True)
          (act): GELU()
          (fc2): Linear(in_features=1024, out_features=256, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
      (1): SwinTransformerBlock(
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (attn): WindowAttention(
          (qkv): Linear(in_features=256, out_features=768, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=256, out_features=256, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
          (softmax): Softmax(dim=-1)
        )
        (drop_path): DropPath(drop_prob=0.200)
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=256, out_features=1024, bias=True)
          (act): GELU()
          (fc2): Linear(in_features=1024, out_features=256, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
    )
  )
)
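One detail worth noticing in the printout: the DropPath probabilities (0.018, 0.036, ..., 0.200) are not arbitrary. In the reference implementation, the stochastic-depth rate increases linearly from 0 to drop_path_rate across all blocks. A quick check, assuming drop_path_rate = 0.2 and the 2+2+6+2 blocks shown above:

```python
import torch

depths = [2, 2, 6, 2]            # blocks per stage, as in the printout above
drop_path_rate = 0.2             # assumed final stochastic-depth rate
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
print([round(p, 3) for p in dpr])
# [0.0, 0.018, 0.036, 0.055, 0.073, 0.091, 0.109, 0.127, 0.145, 0.164, 0.182, 0.2]
```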

Next, we will go through each of these modules one by one.
