transformer中相对位置编码理解

对于一副图像,位置信息占有非常重要的地位,ViT中用了绝对位置编码,Swin中用到了相对位置编码。看了Swin的源码,参考了https://blog.csdn.net/qq_37541097/article/details/121119988?spm=1001.2014.3001.5502

博主的博客,有了自己的一点点理解,在这里跟大家分享一下。

一幅图像中,每个像素有自己的绝对位置,也有相对于其他像素的相对位置,上图展示了每个像素相对于其他像素的相对位置。对于大小为H*W的特征图,每个像素相对于其他像素的相对位置最大为H-1,最小为1-H,构建相对位置的代码如下:

coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))  # [2, Mh, Mw]
coords_flatten = torch.flatten(coords, 1)  # [2, Mh*Mw]
# [2, Mh*Mw, 1] - [2, 1, Mh*Mw]
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # [2, Mh*Mw, Mh*Mw]
relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # [Mh*Mw, Mh*Mw, 2]

 torch.meshgrid的作用是将coords_h和coords_w构造成网格,形成坐标,就相当于生成了绝对位置坐标。然后将绝对位置展平,然后通过

# [2, Mh*Mw, 1] - [2, 1, Mh*Mw]

relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # [2, Mh*Mw, Mh*Mw]

将绝对位置信息转化为相对位置信息。具体做法是将展平后的张量拓展最后一个维度减去张量拓展第一个维度,生成上图的相对位置索引。

但是生成的相对位置索引会有负值,将二维索引变成一维如果将两个维度直接相加,则(-1,0)和(0,-1)是相同的,但是这两个相对位置是不同的。为了解决这个问题,作者用的如下方法:

relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1)  # [Mh*Mw, Mh*Mw]
self.register_buffer("relative_position_index", relative_position_index)

将第一维加上self.window_size[0] - 1 第二维加上self.window_size[1] - 1 这样每一维最小为0,最大值为2*self.window_size[0] - 2,然后第一维乘上2*self.window_size[1] - 1,那么现在最大为

(2*self.window_size[0] - 2)*(2*self.window_size[1] - 1) + (2*self.window_size[1]-2) 最小为0

其中self.window_size[0] 可以= self.window_size[1],那么从0到最大值,那就是最大值加1个数,化简后就变为(2*self.window_size[0]-1)(2*self.window_size[1]-1),那么创建的相对位置偏执table大小为(2*self.window_size[0]-1)(2*self.window_size[1]-1)。根据索引,在tabel中找到对应的偏执,然后偏执是可以更新的。

创建偏执table的代码如下:

self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) 


# [2*Mh-1 * 2*Mw-1, nH]
#对于长度为Wh的窗口,相对位置最高为Wh-1,最低为1-Wh,为了不出现负数,做如下处理:
#先将横纵坐标都加上Wh-1,那么横纵坐标最小为0,最大为2*Wh-2,再将其中一维乘以(2*Wh-1),然后横纵坐标相加,这样就可以将所有的位置坐标区分开,
#最后处理的结果中,最小值为0,最大值为(2*Wh-2)*(2*Wh-1)+(2*Wh-2),那么需要创建最大值加一个相对位置参数,因为是从0开始的,化简一下就是(2*Wh-1)*(2*Wh-1)

目前来说相对位置偏执代码就写好了,只需要在求得注意力分数后,将偏执加上就可以了。

  • 2
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值