ConviT中GPSA位置注意力

def get_rel_indices(self, num_patches: int) -> torch.Tensor:
        img_size = int(num_patches ** .5)
        rel_indices = torch.zeros(1, num_patches, num_patches, 3)
        ind = torch.arange(img_size).view(1, -1) - torch.arange(img_size).view(-1, 1)
        indx = ind.repeat(img_size, img_size)
        indy = ind.repeat_interleave(img_size, dim=0).repeat_interleave(img_size, dim=1)
        indd = indx ** 2 + indy ** 2
        rel_indices[:, :, :, 2] = indd.unsqueeze(0)
        rel_indices[:, :, :, 1] = indy.unsqueeze(0)
        rel_indices[:, :, :, 0] = indx.unsqueeze(0)
        device = self.qk.weight.device
        return rel_indices.to(device)

首先由torch.arange(img_size).view(1,-1)  - torch.arange(img_size).view(-1,1)

产生绝对位置编码如[[0,1,2,3,4,5,6,7,8,9,10,11,12,13]

                                 [-1,0,1,2,3,4,5,6,7,8,9,10,11,12]

                                 [-2,-1,0,1,2,3,4,5,6,7,8,9,10,11]

                                                         ...

                                                         ...

                                 [-13,-12,-11,-10,-9,-8,-7,-6,-5,-4,-3,-2,-1,0]

然后用repeat函数对绝对位置进行重复产生N*2的位置编码

[[0,1,2,3,4,5,6,7,8,9,10,11,12,13,0,1,2,3,4,5,6,7,8,9,10,11,12,13,0,1,2,3,4,5,6,7,8,9,10,11,12,13...]

...

对两个维度进行同样的操作

再用repeat_interleave函数对绝对位置进行重复,产生N*2的位置编码

[[0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,...]

[-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,...]

...]

第三种编码方式是由上面两种方式组合起来的indd = indx**2 + indy**2

然后将三种编码方式cat起来,通过一个映射将3通道映射成num_heads个数,这么做的原因是,多头注意力要进行head个头数的注意力,需要head个注意力矩阵,同个将三通道的位置矩阵映射成heads个,然后reshape成和多头注意力矩阵形状相同的矩阵,以便和多头注意力矩阵进行结合。

  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值