Transformer实战-系列教程8:SwinTransformer 源码解读1(项目配置/SwinTransformer类)

🚩🚩🚩Transformer实战-系列教程总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传
https://download.csdn.net/download/weixin_50592077/88809977?spm=1001.2014.3001.5501

SwinTransformer 算法原理
SwinTransformer 源码解读1(项目配置/SwinTransformer类)
SwinTransformer 源码解读2(PatchEmbed类/BasicLayer类)
SwinTransformer 源码解读3(SwinTransformerBlock类)
SwinTransformer 源码解读4(WindowAttention类)
SwinTransformer 源码解读5(Mlp类/PatchMerging类)

1、项目配置

本项目来自SwinTransformer 的GitHub官方源码:

Image Classification: Included in this repo. See get_started.md for a quick start.
Object Detection and Instance Segmentation: See Swin Transformer for Object Detection.
Semantic Segmentation: See Swin Transformer for Semantic Segmentation.
Video Action Recognition: See Video Swin Transformer.
Semi-Supervised Object Detection: See Soft Teacher.
SSL: Contrasitive Learning: See Transformer-SSL.
SSL: Masked Image Modeling: See get_started.md#simmim-support.
Mixture-of-Experts: See get_started for more instructions.
Feature-Distillation: See Feature-Distillation.

此处包含多个版本(分类、检测、分割、视频 ),但是仅仅学习算法建议选择第一个图像分类的基础版本就可以了

安装需求:

pip install timm==0.4.12
pip install yacs==0.1.8
pip install termcolor==1.1.0
pytorch
opencv
Apex(linux版本)

原本的数据是imagenet,这个数据太多了,有很多开源的imagenet小版本,本文配套的资源就是已经配好的imagenet小版本,目录信息、数据标注、数据划分都已经做好了

本项目的执行文件就是main.py,源码我已经修改了部分

配置参数:

--cfg configs/swin_tiny_patch4_window7_224.yaml
--data-path imagenet
--local_rank 0
--batch-size 4

–local rank 0这个参数表示的是分布式训练,直接用当前的这个就好

2、SwinTransformer类

打开models有两个构建模型的源码:
build.py
swin_transformer.py

构建模型的部分主要就在swin_transformer.py,一共有600多行代码

首先看SwinTransformer类的前向传播函数:

class SwinTransformer(nn.Module):
    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x

打印这个过程的shape值:

  1. 原始输入x: torch.Size([4, 3, 224, 224]),原始输入是一张彩色图像,4是batch,3是通道数,图像是224*244的长宽
  2. self.forward_features(x):torch.Size([4, 768]),经过forward_features函数后,变成了768维的向量
  3. self.head(x):torch.Size([4, 1000]),head是一个全连接层,很显然这个1000是最后的分类数

所以整个体征提取的过程都在self.forward_features()函数中:

    def forward_features(self, x):
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)  # B L C
        x = self.avgpool(x.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1)
        return x
  1. 原始输入x: torch.Size([4, 3, 224, 224]),原始输入是一张彩色图像

  2. patch_embed: torch.Size([4, 3136, 96]),图像经过patch_embbeding变成一个Transformer需要的序列,相当于序列是3136个向量,每个向量维度是96。这个过程通常包括将图像分割成多个patches,然后将每个patch线性投影到一个指定的维度。

  3. if self.ape: x = x + self.absolute_pos_embed,如果模型配置了绝对位置编码(self.ape为真),这行代码会将绝对位置嵌入加到patch的嵌入上。绝对位置嵌入提供了每个patch在图像中位置的信息,帮助模型理解图像中不同部分的空间关系, 不改变维度

  4. pos_drop: torch.Size([4, 3136, 96]),一层Dropout

  5. layer: torch.Size([4, 784, 192]),for循环主要是Swin Transformer Block的堆叠

  6. layer: torch.Size([4, 196, 384]),4次循环,序列长度减小

  7. layer: torch.Size([4, 49, 768]),特征图个数增多,即向量维度变大

  8. layer: torch.Size([4, 49, 768]),最后一次维度不变

  9. norm: torch.Size([4, 49, 768]),层归一化,维度不变

  10. avgpool: torch.Size([4, 768, 1]),平均池化

  11. flatten: torch.Size([4, 768]),拉平操作,去掉多余的维度

SwinTransformer 算法原理
SwinTransformer 源码解读1(项目配置/SwinTransformer类)
SwinTransformer 源码解读2(PatchEmbed类/BasicLayer类)
SwinTransformer 源码解读3(SwinTransformerBlock类)
SwinTransformer 源码解读4(WindowAttention类)
SwinTransformer 源码解读5(Mlp类/PatchMerging类)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

机器学习杨卓越

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值