【零基础讲论文源码】Swin-Transformer源代码阅读

目前这个系列会开两个方向, cv transformer 和OCR方向。

Transformer方向

OCR方向

  • DBnet解读【链接】(正在制作中。。。)
  • PP_OCR【链接】(待续)
  • 待续

整体介绍

Swin-transformer是微软 CVPR2021今年最近一篇非常棒的论文。
Github【源代码地址】
原文地址【地址】
先上个结构图:
在这里插入图片描述
(为方便阅读,代码进行简化)

SwinTransformer

: 主代码

#整体结构中,通过PatchEmbed()分割出图像块,再经过相应层数的BasicLayer()。
class SwinTransformer(nn.Module):
    def __init__():
        super().__init__()

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed()

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer()
            self.layers.append(layer)

        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        #用于输出的分类维度,可以根据自己的需要更改
        
    def forward_features(self, x):
        x = self.patch_embed(x)
        # b h w c -> b (h/4)*(w/4) 16*c
        x = self.pos_drop(x)
        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)  # B L C
        #以下用于进行分类等任务,可以根据需要进行调整
        x = self.avgpool(x.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x


PatchEmbed

:分割图像信息

class PatchEmbed(nn.Module):
    def __init__():
        super().__init__()
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        #flatten(2)等于从2维度开始进行展平操作,x的维度为b c h w ,
        #设patch_size为4,则结果为 b (h/4)*(w/4) 16*c
        
        if self.norm is not None:
            x = self.norm(x)
        return x

论文这里是说用nn.unfold函数,卷而不积,self.proj的卷积+flatten(start_dim)模拟一个unfold的操作(确实卷积肯定是比较优的结果,直接unfold,反而不平滑),并通过patch_size的大小,对图像进行缩小,并分割,传出的特征为b (h/4)(w/4) 16c

#flatten举例
>>> a=torch.randn(1,2,3,4)
>>> a.flatten(2).shape
torch.Size([1, 2, 12])

BasicLayer

作为核心的stage层(文中为4层)

class BasicLayer(nn.Module):

    def __init__(self, ):

        super().__init__
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值