This article is my walkthrough of the Swin Transformer paper (together with the PyTorch version of the code); the figures come from the original paper or from related blog posts.
The code is not the original source either, but snippets extracted from it and paired with simple example data to make them easier to follow.
Of course, a poor choice of example data may itself cause misunderstandings, so corrections are welcome.
Only part of the article is written so far; I am publishing it first and appreciate any feedback.
Figure 1.
(a) The proposed Swin Transformer builds hierarchical feature maps by merging image patches (shown in gray) in deeper layers,
and has linear computation complexity to input image size due to computation of self-attention only within each local window (shown in red).
It can thus serve as a general-purpose backbone for both image classification and dense recognition tasks.
(b) In contrast, previous vision Transformers produce feature maps of a single low resolution and have quadratic computation complexity to input image size due to computation of self-attention globally.
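The complexity claim comes from Eq. (1) and (2) of the paper: Ω(MSA) = 4hwC² + 2(hw)²C for global attention versus Ω(W-MSA) = 4hwC² + 2M²hwC for window attention, where h×w is the number of tokens, C the channel dimension and M the window size. Below is a minimal sketch of mine (not code from the paper or from model.py) that plugs in the Stage-1 numbers of a 224×224 input to show the gap:
def msa_flops(h, w, C):
    # global self-attention: the 2*(h*w)^2*C term is quadratic in the number of tokens
    return 4 * h * w * C ** 2 + 2 * (h * w) ** 2 * C

def wmsa_flops(h, w, C, M=7):
    # window self-attention: linear in h*w for a fixed window size M
    return 4 * h * w * C ** 2 + 2 * M ** 2 * h * w * C

# Stage 1 of Swin-T on a 224x224 input: h = w = 224 / 4 = 56, C = 96, M = 7
print(msa_flops(56, 56, 96) / 1e9)    # ~2.0 GFLOPs, dominated by the quadratic term
print(wmsa_flops(56, 56, 96) / 1e9)   # ~0.15 GFLOPs, and it grows linearly with h*w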
Model architecture
Figure 3.
(a) The architecture of a Swin Transformer (Swin-T);
(b) two successive Swin Transformer Blocks (notation presented with Eq. (3)).
W-MSA and SW-MSA are multi-head self attention modules with regular and shifted windowing configurations, respectively.
Stage 1 – Patch Embedding
It first splits an input RGB image into non-overlapping patches by a patch splitting module, like ViT.
Each patch is treated as a “token” and its feature is set as a concatenation of the raw pixel RGB values.
In our implementation, we use a patch size of 4×4 and thus the feature dimension of each patch is 4×4×3 = 48 (the 3 being the RGB channels).
A linear embedding layer is applied on this raw-valued feature to project it to an arbitrary dimension (denoted as C).
Personally I find the term "linear embedding layer" slightly imprecise (in the code it is implemented as a convolution), but the second half of the sentence is accurate: it projects the 3 input channels to an embedding dimension of 96. A small sketch of this conv-as-linear view follows below.
Several Transformer blocks with modified self-attention computation (Swin Transformer blocks) are applied on these patch tokens.
The Transformer blocks maintain the number of tokens (H/4 × W/4), and together with the linear embedding are referred to as “Stage 1”.
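As noted above, the patch splitting plus the "linear embedding layer" is implemented in the code below as a single Conv2d with kernel_size = stride = 4. Here is a minimal sketch of mine (not part of model.py) showing that this convolution is exactly the 4×4 split followed by a 48→96 linear projection:
import torch
import torch.nn as nn

x = torch.randn(1, 3, 224, 224)                                      # [N, C, H, W]

# explicit version: cut into non-overlapping 4x4 patches, flatten each to 4*4*3 = 48 values, project to 96
patches = x.unfold(2, 4, 4).unfold(3, 4, 4)                          # [1, 3, 56, 56, 4, 4]
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(1, 56 * 56, 48)  # [1, 3136, 48]
out_linear = nn.Linear(48, 96)(patches)                              # [1, 3136, 96]

# conv version, as used in PatchEmbed below; reshaping its [96, 3, 4, 4] weight to [96, 48]
# gives the same linear map, so the two differ only in their random initialization here
out_conv = nn.Conv2d(3, 96, kernel_size=4, stride=4)(x)              # [1, 96, 56, 56]
out_conv = out_conv.flatten(2).transpose(1, 2)                       # [1, 3136, 96]

print(out_linear.shape, out_conv.shape)   # both torch.Size([1, 3136, 96])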
Code
The following code is taken from model.py:
"""
@ time : 2024/12/17
"""
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
class PatchEmbed(nn.Module):
"""
2D Image to Patch Embedding
"""
def __init__(self, patch_size=4, in_c=3, embed_dim=96, norm_layer=None):
super().__init__()
patch_size = (patch_size, patch_size)
self.patch_size = patch_size
self.in_chans = in_c
self.embed_dim = embed_dim
self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
_, _, H, W = x.shape
# padding
        # if H or W of the input image is not an integer multiple of patch_size, pad it up to one
pad_input = (H % self.patch_size[0] != 0) or (W % self.patch_size[1] != 0)
if pad_input:
# to pad the last 3 dimensions,
# (W_left,W_right, H_top,H_bottom, C_front,C_back)
x = F.pad(x,
(0, self.patch_size[1] - W % self.patch_size[1],
0, self.patch_size[0] - H % self.patch_size[0],
0, 0
)
)
        # downsample H and W by a factor of patch_size (the stride-4 convolution)
x = self.proj(x)
_, _, H, W = x.shape
# flatten: [B, C, H, W] -> [B, C, HW]
# transpose: [B, C, HW] -> [B, HW, C]
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
print(x.shape)
# torch.Size([1, 3136, 96])
# 224/4 * 224/4 = 3136
return x, H, W
if __name__ == '__main__':
img_path = "tulips.jpg"
img = Image.open(img_path)
plt.imshow(img)
    # note: PIL's Image.size is (W, H); the [N, C, H, W] tensor layout only appears after ToTensor and unsqueeze
    print(img.size)
    # (500, 375)
img_size = 224
data_transform = transforms.Compose(
[transforms.Resize(int(img_size * 1.14)),
transforms.CenterCrop(img_size),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]
)
    # [C, H, W] -> [1, C, H, W]
    img = data_transform(img)
    img = torch.unsqueeze(img, dim=0)

    patch_embed = PatchEmbed(patch_size=4, in_c=3, embed_dim=96)
    x, H, W = patch_embed(img)
    # x: torch.Size([1, 3136, 96]), H = W = 224 / 4 = 56
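To see the padding branch in forward() fire, feed the module an input whose height and width are not multiples of 4. A quick check of mine (assuming the PatchEmbed class above is in scope):
x = torch.randn(1, 3, 225, 223)
# H is padded 225 -> 228 and W 223 -> 224, so the stride-4 conv yields a 57 x 56 token grid
out, H, W = PatchEmbed()(x)
print(out.shape, H, W)   # torch.Size([1, 3192, 96]) 57 56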