Coatnet中的Rel-Attention

Coatnet中的Rel-Attention

输入:img.shape= [batch_size,(H*W),C]

multi-attention

生成qkv的过程如下:

请添加图片描述

原本的transformer中qk会除以middle_dim的1/2次方,再进行softmax。先不着急,Coatnet中qk还得加上一个relative矩阵再进行其他操作。

relative_attention

在原论文中是产生一个自定义矩阵,大小为[2*H-1,2*W-1]的矩阵,但是需要的应该是一个【H*W,H*W】的矩阵,因为原图中的每两个点之间都应该有一个映射权重P((i,j)–> (i’,j’)),论文中是相对位移相同的两点之间的权重P相同,例如(1,1)到(2,2)的映射权重与(3,3)到(4,4)的映射权重是使用同一个权重

接下来就形象的看一下例子是怎么做的:

输入图片大小3*3

请添加图片描述

创造的权重矩阵大小【2*3-1,2*3-1】

请添加图片描述

然后你把原图想象成卷积核,权重矩阵是被卷积图片,步进stride=1

输出的大小跟输入图片大小一样都是【3,3】

所以【H*W,H*W】的权重矩阵怎么来的?就是将每一步进行相乘的矩阵拼接起来:

请添加图片描述

输入是3*3的图,输出也是3*3的图,点到点的映射权重都能在上图【H*W,H*W】的权重矩阵中找到。

那代码中怎么实现的呢,代码中是先自定义一个大小为【(2*H-1)*(2*W-1),heads】的矩阵,然后经过各种骚操作生成【1,heads,(H*W),(H*W)】的矩阵

看代码(我看了很久才看懂):

__init__()部分:
self.relative_bias_table = nn.Parameter(
torch.zeros((2 * self.ih - 1) * (2 * self.iw - 1), heads))

coords = torch.meshgrid((torch.arange(self.ih), torch.arange(self.iw)))
coords = torch.flatten(torch.stack(coords), 1)
relative_coords = coords[:, :, None] - coords[:, None, :]

relative_coords[0] += self.ih - 1
relative_coords[1] += self.iw - 1
relative_coords[0] *= 2 * self.iw - 1
relative_coords = rearrange(relative_coords, 'c h w -> h w c')
relative_index = relative_coords.sum(-1).flatten().unsqueeze(1)
self.register_buffer("relative_index", relative_index)
forward()部分:
 relative_bias = self.relative_bias_table.gather(
            0, self.relative_index.repeat(1, self.heads))
        relative_bias = rearrange(
            relative_bias, '(h w) c -> 1 c h w', h=self.ih*self.iw, w=self.ih*self.iw)

融合Attention

将得到的qk矩阵和relative矩阵直接相加就好了(batch_size不用管),但注意qk矩阵与relative矩阵每个点的意义要相同,例如qk矩阵中是某个位置的权重是(1,1)到(2,2)的映射权重,那么relative矩阵对应位置也得是(1,1)到(2,2)的映射权重,注意生成relative矩阵变维度的方式。

两个矩阵相加后就是除以middle_dim的1/2次方,再softmax,最后乘以multi-attention中出来的V矩阵。

  • 3
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值