Swin Transformer: paper reading notes & code analysis


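These are my notes from working through the principles of the Swin Transformer paper, read alongside the PyTorch implementation. The figures come from the original paper or from other related blog posts.

The code shown here is not the original code either: it consists of fragments extracted from it and paired with simple data to make them easier to understand.

Of course, a poorly chosen input may lead to misunderstandings, so corrections are welcome.

This is only a first part. I am publishing it early and would appreciate any feedback.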


Figure 1.
(a) The proposed Swin Transformer builds hierarchical feature maps by merging image patches (shown in gray) in deeper layers,
and has linear computation complexity to input image size due to computation of self-attention only within each local window (shown in red).
It can thus serve as a general-purpose backbone for both image classification and dense recognition tasks.
(b) In contrast, previous vision Transformers produce feature maps of a single low resolution and have quadratic computation complexity to input image size due to computation of self-attention globally.
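
The linear-versus-quadratic claim in the caption can be checked with a little arithmetic. The sketch below uses the complexity formulas from the paper (Eqs. 1 and 2), Ω(MSA) = 4hwC² + 2(hw)²C and Ω(W-MSA) = 4hwC² + 2M²hwC, where h×w is the number of patches, C the channel dimension and M the window size. The function names and the example values (C = 96, M = 7) are my own choices for illustration.

```python
def msa_flops(h, w, C):
    """Global multi-head self-attention: 4hwC^2 + 2(hw)^2 C."""
    return 4 * h * w * C ** 2 + 2 * (h * w) ** 2 * C


def wmsa_flops(h, w, C, M):
    """Window-based self-attention with M x M windows: 4hwC^2 + 2 M^2 hw C."""
    return 4 * h * w * C ** 2 + 2 * M ** 2 * h * w * C


if __name__ == "__main__":
    C, M = 96, 7  # illustrative values, not taken from the code in this post
    for side in (56, 112, 224):
        global_cost = msa_flops(side, side, C)
        window_cost = wmsa_flops(side, side, C, M)
        print(f"{side}x{side} tokens: MSA={global_cost:.3e}, "
              f"W-MSA={window_cost:.3e}, ratio={global_cost / window_cost:.1f}")
```

Doubling the side length multiplies the number of tokens by four. The W-MSA cost grows by exactly that factor, while the global MSA cost grows faster because of the (hw)² term.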

Model architecture

Figure 3.
(a) The architecture of a Swin Transformer (Swin-T);
(b) two successive Swin Transformer Blocks (notation presented with Eq. (3)).
W-MSA and SW-MSA are multi-head self attention modules with regular and shifted windowing configurations, respectively.

Stage 1 – Patch Embedding

It first splits an input RGB image into non-overlapping patches by a patch splitting module, like ViT.

Each patch is treated as a “token” and its feature is set as a concatenation of the raw pixel RGB values.

In our implementation, we use a patch size of 4×4 and thus the feature dimension of each patch is 4×4×3 = 48 (the 3 is the number of RGB channels).
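
To make the 4×4×3 = 48 figure concrete, here is a minimal sketch that cuts an image tensor into non-overlapping patches using only reshape and permute. `split_into_patches` is a hypothetical helper written for this note, not a function from model.py, and it assumes H and W are already multiples of the patch size.

```python
import torch


def split_into_patches(img, patch_size=4):
    """[B, C, H, W] -> [B, (H/p)*(W/p), p*p*C]; H and W must be multiples of p."""
    B, C, H, W = img.shape
    p = patch_size
    # Split H into (patch row index, row inside patch) and W likewise.
    x = img.reshape(B, C, H // p, p, W // p, p)
    # Move the two patch-grid axes to the front: [B, H/p, W/p, p, p, C]
    x = x.permute(0, 2, 4, 3, 5, 1)
    # One row per patch, with its raw pixel values concatenated.
    return x.reshape(B, (H // p) * (W // p), p * p * C)


img = torch.rand(1, 3, 224, 224)
tokens = split_into_patches(img)
print(tokens.shape)  # torch.Size([1, 3136, 48])
```

Each of the 56 × 56 = 3136 tokens holds the 48 raw values of one patch, before any learned projection is applied.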

A linear embedding layer is applied on this raw-valued feature to project it to an arbitrary dimension (denoted as C).
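I feel the wording "linear embedding layer" is not quite accurate (the code uses an nn.Conv2d), but the second half of the sentence holds: the channel dimension goes from 3 to 96.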
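
The paper's "linear embedding" and the Conv2d in the code turn out to be the same operation when the kernel size and stride both equal the patch size. The following sketch checks this numerically. The variable names and the weight copy are mine, added only to line the two layers up; nothing here is taken from model.py beyond the Conv2d configuration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
patch, in_c, embed_dim = 4, 3, 96

conv = nn.Conv2d(in_c, embed_dim, kernel_size=patch, stride=patch)
linear = nn.Linear(in_c * patch * patch, embed_dim)

# Copy the conv parameters into the linear layer.
# conv.weight is [96, 3, 4, 4]; flattening the last three dims gives [96, 48].
with torch.no_grad():
    linear.weight.copy_(conv.weight.reshape(embed_dim, -1))
    linear.bias.copy_(conv.bias)

img = torch.rand(1, in_c, 224, 224)

# Path 1: strided conv, then flatten to [B, HW, C] as PatchEmbed does.
out_conv = conv(img).flatten(2).transpose(1, 2)

# Path 2: extract 4x4x3 patches (48 values each, ordered C, kh, kw),
# then apply an ordinary linear layer to every patch.
patches = F.unfold(img, kernel_size=patch, stride=patch)  # [1, 48, 3136]
out_linear = linear(patches.transpose(1, 2))              # [1, 3136, 96]

print(out_conv.shape, out_linear.shape)
print(torch.allclose(out_conv, out_linear, atol=1e-5))
```

So the convolution is best read as an efficient way of applying one shared linear layer to every 48-dimensional patch.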

Several Transformer blocks with modified self-attention computation (Swin Transformer blocks) are applied on these patch tokens.

The Transformer blocks maintain the number of tokens (H/4 × W/4), and together with the linear embedding are referred to as “Stage 1”.

Code

The following code comes from model.py:

"""
@ time : 2024/12/17
"""
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F


class PatchEmbed(nn.Module):
    """
    2D Image to Patch Embedding
    """

    def __init__(self, patch_size=4, in_c=3, embed_dim=96, norm_layer=None):
        super().__init__()
        patch_size = (patch_size, patch_size)

        self.patch_size = patch_size
        self.in_chans = in_c
        self.embed_dim = embed_dim
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        _, _, H, W = x.shape

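        # padding
        # If H or W of the input image is not a multiple of patch_size, pad it
        pad_input = (H % self.patch_size[0] != 0) or (W % self.patch_size[1] != 0)

        if pad_input:
            # F.pad takes pairs starting from the last dimension:
            # (W_left, W_right, H_top, H_bottom, C_front, C_back)
            # The outer modulo keeps the padding at 0 for a dimension that is
            # already a multiple of patch_size.
            pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
            pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
            x = F.pad(x, (0, pad_w, 0, pad_h, 0, 0))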

        # Downsample by a factor of patch_size
        x = self.proj(x)

        _, _, H, W = x.shape

        # flatten: [B, C, H, W] -> [B, C, HW]
        # transpose: [B, C, HW] -> [B, HW, C]
        x = x.flatten(2).transpose(1, 2)

        x = self.norm(x)
        print(x.shape)
        # For a 1x3x224x224 input: torch.Size([1, 3136, 96]),
        # since 224/4 * 224/4 = 3136

        return x, H, W


if __name__ == '__main__':
    img_path = "tulips.jpg"

    img = Image.open(img_path)
    plt.imshow(img)
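    # PIL reports size as (width, height)
    print(img.size)
    # (500, 375)

    # Preprocessing: resize, center-crop to img_size, convert to tensor, normalize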
    img_size = 224
    data_transform = transforms.Compose(
        [transforms.Resize(int(img_size * 1.14)),
         transforms.CenterCrop(img_size),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]
    )
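    # NOTE: the original listing breaks off after "img". The lines below are an
    # assumed completion, chosen to match the shape printed in forward().
    img = data_transform(img)          # [C, H, W] = [3, 224, 224]
    img = torch.unsqueeze(img, dim=0)  # add batch dim -> [N, C, H, W]

    patch_embed = PatchEmbed()
    x, H, W = patch_embed(img)         # x: [1, 3136, 96], H = W = 56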
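
The padding branch in forward only matters when H or W is not a multiple of the patch size. Here is a small sketch of the arithmetic in plain Python. `padded_grid` is a hypothetical helper, not part of model.py, and it assumes padding is added only on the right and bottom, up to the next multiple of the patch size.

```python
def padded_grid(H, W, patch_size=4):
    """Return (pad_bottom, pad_right, tokens_h, tokens_w) for an H x W input."""
    # The outer modulo yields 0 when the dimension is already divisible.
    pad_h = (patch_size - H % patch_size) % patch_size
    pad_w = (patch_size - W % patch_size) % patch_size
    return pad_h, pad_w, (H + pad_h) // patch_size, (W + pad_w) // patch_size


for H, W in [(224, 224), (375, 500), (225, 224)]:
    pad_h, pad_w, th, tw = padded_grid(H, W)
    print(f"{H}x{W}: pad_bottom={pad_h}, pad_right={pad_w}, "
          f"tokens={th}x{tw}={th * tw}")
```

The outer `% patch_size` keeps an already-divisible dimension at zero padding. Without it, that dimension would gain one extra row or column of patches whenever the other dimension triggers the padding branch.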