讲解下swin transformer attention mask生成的核心

 本文启发来源于多篇博客。文末附有一些链接

h_slices = (slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None))
w_slices = (slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None))
cnt = 0
for h in h_slices:
    for w in w_slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1
 
mask_windows = window_partition(img_mask, window_size)  # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, window_size * window_size)
 
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

上面的几行的目的是生成一个类似于这样的矩阵

         [[0., 0., 0., 0., 1., 1., 2., 2.],
         [0., 0., 0., 0., 1., 1., 2., 2.],
         [0., 0., 0., 0., 1., 1., 2., 2.],
         [0., 0., 0., 0., 1., 1., 2., 2.],
         [3., 3., 3., 3., 4., 4., 5., 5.],
         [3., 3., 3., 3., 4., 4., 5., 5.],
         [6., 6., 6., 6., 7., 7., 8., 8.],
         [6., 6., 6., 6., 7., 7., 8., 8.]]

后来在partition了下变成类似这样
        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [1., 1., 2., 2., 1., 1., 2., 2., 1., 1., 2., 2., 1., 1., 2., 2.],
         [3., 3., 3., 3., 3., 3., 3., 3., 6., 6., 6., 6., 6., 6., 6., 6.],
         [4., 4., 5., 5., 4., 4., 5., 5., 7., 7., 8., 8., 7., 7., 8., 8.]]

维度是4 ✖49

我想讲的核心是这个

attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)

这里分别在第一维度第二维度扩充

4*49*1-4*1*49维,会自动广播成4*49*49。

我重点就说这里,因为前面计算attention Q✖V后

我们也会得到一个49*49的矩阵。他是从7*7窗口展平后的两个向量进行乘法计算的。

而这里的广播分别是行和列的广播,即行复制和列复制。

我拿几个元素举例,需要从矩阵乘法去思考。

第一行第一列的元素代表相乘时的第一个元素和向量2的第一个元素,二者相减会为0也就意味着他们来自同一个区,可以计算attention

第一行第二列的元素代表相乘时的第一个元素和向量2的第二个元素,二者相减会为0也就意味着他们来自同一个区,可以计算attention

而这样的相减如果有一个元素不为0,从他是第几行第几列就能看出他来源于第一个向量的第几个元素和第二个向量的第几个元素,这个位置相乘的结果不能做attention

Swin-Transformer网络结构详解

【深度学习】详解 Swin Transformer (SwinT)

史上最详细的Swin-Transformer 掩码机制(mask of window attentation)————shaoshuai

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值