[原理理解] Swin Transformer相对位置编码理解

简述

在看Swin Transformer的时候,一开始在相对位置编码这一块的理解上卡壳了挺久,也没有充分理解为什么这么做,在这记录一下自己的一些理解,以防之后忘记。

相对位置编码的意义

GPT : 用来表示一个像素或特征点相对于另一个像素或特征点的位置关系。在处理窗口(window)或局部区域时,计算相对位置索引可以帮助模型更好地捕捉局部结构和上下文信息。

直观理解

注意力

例如现在是2x2的像素窗口,想要计算他们的相对位置关系,那怎么计算?首先需要先理解一下多头自注意力机制是在搞啥子,用语言的理解就是,当前这一句话和自己的关系(简单理解)。例如下面这句话形成的注意力第一行就是“我"和 “我”、“在”、“吃”、“饭” 四个字的注意力关系。
在这里插入图片描述
以图像像素理解:由于像素是二维形式,存在行列关系,因此通常需要裁减成小窗口(windows),计算注意力关系,小窗口又想像一维计算注意力这么方便,那只能把像素进行平铺,以2x2窗口为例,需要平铺成4x1计算他们的注意力,跟上图类似,“我在吃饭”,可以理解为一个2x2窗口。

相对位置获取必要性

由于我们在像素上进行平铺,我们想要在注意力上加上位置信息,为啥要这样做?我的理解是像素平铺的方式,把原本像素的间隔拉大了,可能会加大网络学习的难度,以下图为例,左手左脚本来上下仅间隔一个像素,有强位置约束关系,但经过像素平铺为一维后,间隔变大,可能会比较难找到两者的关系。所以需要对位置进行编码,让网络知道左脚和左手位置相近。
在这里插入图片描述

当前位置初步获取

由于像素本身存在行列关系,因此使用行列进行位置编码是最合适的方式。首先使用torch.arangetorch.meshgridtorch.stack 函数形成 行列坐标,然后将行列坐标使用torch.flatten平铺成一维。这个时候整个坐标尺寸是[2,4] ,其中2代表x、y两个坐标,4代表4个像素,如图所示,按照顺序分别对应(0,0),(0,1),(1,0),(1,1)四个像素坐标。
在这里插入图片描述
代码如下:

import torch
## 一步步理解,以2x2的window size 为例

# 步骤1 得到当前window 下的xy坐标
window_size=(2,2)
coords_h = torch.arange(window_size[0]) # 0-1行
coords_w = torch.arange(window_size[1]) #0-1列
coords = torch.meshgrid([coords_h, coords_w]) #形成两个坐标,分别对应 行、列
coords = torch.stack(coords) ## -> 2*(wh, ww) #将两个坐标堆叠起来,得到某个位置的xy坐标
print("coords shape:",coords.shape) #torch.Size([2, 2, 2]),第一维的2代表x,y坐标
print(coords)
## 步骤2,将纵坐标h,横坐标w,平摊,做成2维张量
coords_flatten = torch.flatten(coords, 1) 
print(coords_flatten.shape)# torch.Size([2, 4]),第一维的2代表x,y坐标
print(coords_flatten) # 2, Wh*Ww ,整个窗口所有的h,w索引

利用广播机制获取相对位置索引XY

现在,想获取当前像素和其他像素的相对位置,应该怎么操作?可以直接利用广播机制,列扩展维度作为当前像素位置,行扩展维度作为其他像素的位置,两者相减得到相对xy坐标。 图1,第一行的每一列都是第一个像素;图2,第一行的4列对应4个像素位置;两者相减得到图3,第一行代表是第一个像素和所以4个像素的相对位置关系。
在这里插入图片描述
代码:

## 步骤3,利用广播机制,得到相对位置,举例图
relative_coords_first = coords_flatten[:, :, None]  # 2, wh*ww, 1   # 当前窗口扩展列
relative_coords_second = coords_flatten[:, None, :] # 2, 1, wh*ww   #当前窗口扩展行
relative_coords = relative_coords_first - relative_coords_second # 最终得到 2, wh*ww, wh*ww 

relative_coords = relative_coords.permute(1, 2, 0).contiguous() #为u都变换,变成Wh*Ww, Wh*Ww, 2,相对坐标
print(relative_coords[:,:,0])
print(relative_coords[:,:,1])

获取最后相对位置1

现在,我们想要获取非负数的位置索引,怎么做呢?首先我们需要先知道相对位置最小,最大值是多少?
最大值就是当前像素是第一个像素的时候最后一个像素的位置(windowsize -1 , windowsize -1)
最小值就是当前像素是最后一个像素时候第一个像素的位置(-(windowsize-1) ,-(windowsize-1))
因此,对负数进行偏移需要X、Y 各自加上 windowsize-1
在这里插入图片描述
现在,我们已经获取到非负的xy相对位置索引,需要做最后一个步骤,把两个索引映射成单一的维度的索引。能想到的最简单方式就是x+y,但是这个方式是不行的。如下图所示,如果直接两者相加,那么针对同一个像素,其他像素跟他的相对位置索引就会重复。例如第一个像素 和 (第二个像素、第三个像素)索引位置都是1.
在这里插入图片描述

获取最后相对位置2

那么,需要使用什么计算方法,才能让二维索引映射成单维索引呢?回想起二维数组 reshape成一维数组,其他行的单索引是 y*len + x,其中len是列数目,也就是一行有几个数。在相对位置索引中,一行最多的数目又是多少呢?
原本第一个元素和最后一个元素的索引值相差最多,分别是 windowsize - 1-(windowsize -1)如下图所示。
在这里插入图片描述
也就是每一行的值范围是[(-(windowsize -1)) , (windowsize-1) ] ,在上面步骤中,我们让偏移的最小值变成0,也就是索引值范围是 [0 , 2*windowsize -2],总的有 2*windowsize -1 个数目,所以下一行第一个索引的值是,y*(2*windowsize -1)+ x
想象一下,如果现在有一个3x3 数组,那么第二行第一个元素的索引是不是 1*3 + 0 = 3
在这里插入图片描述
这就是为啥相对位置索引要乘以 2 * window_size[1] - 1。
具体代码如下

relative_coords[:, :, 0] += window_size[0] - 1 # 
relative_coords[:, :, 1] += window_size[1] - 1
print(relative_coords[:,:,0])
print(relative_coords[:,:,1])
print(relative_coords[:,:,0]+relative_coords[:,:,1])
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
print(relative_position_index)

最终的相对位置值嵌入

具体的相对位置值加入注意力并不是直接依靠这个索引,而是创建一个可学习参数的table,利用上面的位置索引到这个table里面去找相应值。第一个元素和最后一个元素位置相差最多,正向距离是 windowsize -1 ,反向距离是 -(windowsize -1) ,在加上本身相对位置是0,所以总的相对位置有(2 * window_size[0] - 1) * (2 * window_size[1] - 1) 个值,而不是window_size[0]*window_size[1] * window_size[0]*window_size[1]个,这也是这个可学习table的维度。
最后根据相对位置索引去找table的值就可以啦

relative_position_bias_table = torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
print(relative_position_bias_table.shape)
print(relative_position_bias_table[relative_position_index.view(-1)].shape) #16x6
  • 8
    点赞
  • 20
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
基于YOLOv9实现工业布匹缺陷(破洞、污渍)检测系统python源码+详细运行教程+训练好的模型+评估 【使用教程】 一、环境配置 1、建议下载anaconda和pycharm 在anaconda中配置好环境,然后直接导入到pycharm中,在pycharm中运行项目 anaconda和pycharm安装及环境配置参考网上博客,有很多博主介绍 2、在anacodna中安装requirements.txt中的软件包 命令为:pip install -r requirements.txt 或者改成清华源后再执行以上命令,这样安装要快一些 软件包都安装成功后才算成功 3、安装好软件包后,把anaconda中对应的python导入到pycharm中即可(不难,参考网上博客) 二、环境配置好后,开始训练(也可以训练自己数据集) 1、数据集准备 需要准备yolo格式的目标检测数据集,如果不清楚yolo数据集格式,或者有其他数据训练需求,请看博主yolo格式各种数据集集合链接:https://blog.csdn.net/DeepLearning_/article/details/127276492 里面涵盖了上百种yolo数据集,且在不断更新,基本都是实际项目使用。来自于网上收集、实际场景采集制作等,自己使用labelimg标注工具标注的。数据集质量绝对有保证! 本项目所使用的数据集,见csdn该资源下载页面中的介绍栏,里面有对应的下载链接,下载后可直接使用。 2、数据准备好,开始修改配置文件 参考代码中data文件夹下的banana_ripe.yaml,可以自己新建一个不同名称的yaml文件 train:训练集的图片路径 val:验证集的图片路径 names: 0: very-ripe 类别1 1: immature 类别2 2: mid-ripe 类别3 格式按照banana_ripe.yaml照葫芦画瓢就行,不需要过多参考网上的 3、修改train_dual.py中的配置参数,开始训练模型 方式一: 修改点: a.--weights参数,填入'yolov9-s.pt',博主训练的是yolov9-s,根据自己需求可自定义 b.--cfg参数,填入 models/detect/yolov9-c.yaml c.--data参数,填入data/banana_ripe.yaml,可自定义自己的yaml路径 d.--hyp参数,填入hyp.scratch-high.yaml e.--epochs参数,填入100或者200都行,根据自己的数据集可改 f.--batch-size参数,根据自己的电脑性能(显存大小)自定义修改 g.--device参数,一张显卡的话,就填0。没显卡,使用cpu训练,就填cpu h.--close-mosaic参数,填入15 以上修改好,直接pycharm中运行train_dual.py开始训练 方式二: 命令行方式,在pycharm中的终端窗口输入如下命令,可根据自己情况修改参数 官方示例:python train_dual.py --workers 8 --device 0 --batch 16 --data data/coco.yaml --img 640 --cfg models/detect/yolov9-c.yaml --weights '' --name yolov9-c --hyp hyp.scratch-high.yaml --min-items 0 --epochs 500 --close-mosaic 15 训练完会在runs/train文件下生成对应的训练文件及模型,后续测试可以拿来用。 三、测试 1、训练完,测试 修改detect_dual.py中的参数 --weights,改成上面训练得到的best.pt对应的路径 --source,需要测试的数据图片存放的位置,代码中的test_imgs --conf-thres,置信度阈值,自定义修改 --iou-thres,iou阈值,自定义修改 其他默认即可 pycharm中运行detect_dual.py 在runs/detect文件夹下存放检测结果图片或者视频 【特别说明】 *项目内容完全原创,请勿对项目进行外传,或者进行违法等商业行为! 【备注】 1、该资源内项目代码都经过测试运行成功,功能ok的情况下才上传的,请放心下载使用!有问题请及时沟通交流。 2、适用人群:计算机相关专业(如计科、信息安全、数据科学与大数据技术、人工智能、通信、物联网、自动化、电子信息等)在校学生、专业老师或者企业员工下载使用。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

RichardCV

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值