SwinTransformer中SW-MSA中attn_mask生成逻辑纪录

input_resolution = (12, 12)
window_size = 6
shift_size = 3

生成部分的源码如下:

		if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            # 对 [H, W] 大小进行分区
            # 分区的目的在于,shift之后,进行window划分时,
            # 一个window内包含多个区域,可能彼此不相临,需要进行标号区分
            # 数字相同表示在shift之前区域相邻
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            
            """
            # 也可以按如下分区
            # 数字相同表示在shift之前区域相邻
            h_slices = (slice(0, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.shift_size),
                        slice(-self.shift_size, None))
            """
            
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

			# windows 划分
            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            # attn_mask 计算
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

window实际划分时不是按照中线四等份划分的
更新:SwinTransformer中下采样非四等份中线划分,window划分应该是按照中线四等份划分的。
在这里插入图片描述

  • 1
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值