Relative Positional Bias -- [Swin-transformer]

论文中对于这一块的描述不是很清楚,特意记录一下学习过程。
这篇博客讲解的很清楚,请参考阅读https://blog.csdn.net/qq_37541097/article/details/121119988
在这里插入图片描述

以下通过代码形式运行一个demo帮助理解。

1.假设window的H,W均为2,首先构造一个二维坐标
x= torch.arange(2)
y= torch.arange(2)

#输入为一维序列,输出两个二维网格,常用来生成坐标
ox,oy = torch.meshgrid([x,y])

#按照某个维度拼接,输入序列shape必须一致,默认按照dim0
o2 = torch.stack((ox,oy))

print(ox,oy)
print(o2,o2.shape)

coords = torch.flatten(o2,1)
print(coords,coords.shape)

输出

tensor([[0, 0],
        [1, 1]]) tensor([[0, 1],
        [0, 1]])
tensor([[[0, 0],
         [1, 1]],

        [[0, 1],
         [0, 1]]]) torch.Size([2, 2, 2])
#得到2行序列,对应x,y轴的坐标    
tensor([[0, 0, 1, 1],
        [0, 1, 0, 1]]) 
 torch.Size([2, 4])
   

计算相对坐标索引时,采用了一种我之前没见过的扩张维度的方法,简介高效
print(coords[:,:,None].shape) #相当于增加一个维度
print(coords[:,None,:],coords[:,None,:].shape)
print(coords[:,None,:,None].shape)
#作用与unsqueeze()相同
coords.unsqueeze(1)==coords[:,None,:]

输出

torch.Size([2, 4, 1])
tensor([[[0, 0, 1, 1]],

        [[0, 1, 0, 1]]]) 
        
torch.Size([2, 1, 4])

torch.Size([2, 1, 4, 1])

tensor([[[True, True, True, True]],

        [[True, True, True, True]]])
print(coords[:,:,None]) #相当于增加一个维度
print(coords[:,None,:])

输出

tensor([[[0],
         [0],
         [1],
         [1]],

        [[0],
         [1],
         [0],
         [1]]])
tensor([[[0, 0, 1, 1]],

        [[0, 1, 0, 1]]])
tensor([[[True, True, True, True]],

        [[True, True, True, True]]])
2.计算相对索引
relative_coords=coords[:,:,None]-coords[:,None,:]  #(2,16,1)-(2,1,16)  #广播机制相减
print(f"relative_coords:{relative_coords.shape}={coords[:,:,None].shape}-{coords[:,None,:].shape }","\n",{relative_coords})

输出

#这里相减,应该是使用了广播机制,先扩展到相同shape后,再进行元素相减运算
relative_coords:torch.Size([2, 4, 4])=torch.Size([2, 4, 1])-torch.Size([2, 1, 4]) 
 {tensor([[[ 0,  0, -1, -1],
         [ 0,  0, -1, -1],
         [ 1,  1,  0,  0],
         [ 1,  1,  0,  0]],

        [[ 0, -1,  0, -1],
         [ 1,  0,  1,  0],
         [ 0, -1,  0, -1],
         [ 1,  0,  1,  0]]])}

转换为[4,4,2],相当于得到4个4*2的坐标对,一行横坐标,一行纵坐标

relative_coords=relative_coords.permute(1,2,0).contiguous()
print(relative_coords)

输出

torch.Size([4, 4, 2])
tensor([[[ 0,  0],
         [ 0, -1],
         [-1,  0],
         [-1, -1]],

        [[ 0,  1],
         [ 0,  0],
         [-1,  1],
         [-1,  0]],

        [[ 1,  0],
         [ 1, -1],
         [ 0,  0],
         [ 0, -1]],

        [[ 1,  1],
         [ 1,  0],
         [ 0,  1],
         [ 0,  0]]])


print(relative_coords[:,:,0])  #输出第一列元素对应输入中第一列的第1个元素集合 ,第二列对应输入第一列的第2个元素集合
print(relative_coords[:,:,1])

输出

tensor([[ 0,  0, -1, -1],
        [ 0,  0, -1, -1],
        [ 1,  1,  0,  0],
        [ 1,  1,  0,  0]])
tensor([[ 0, -1,  0, -1],
        [ 1,  0,  1,  0],
        [ 0, -1,  0, -1],
        [ 1,  0,  1,  0]])
window_size=(2,2)

#行、列元素都加上M-1 ,这里M=2
relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
print(relative_coords)
relative_coords[:, :, 1] += window_size[1] - 1
print(relative_coords)


relative_coords[:, :, 0] *= 2 * window_size[1] - 1
print(relative_coords)
relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
print(relative_position_index)

输出

#第一列(行)加M-1
tensor([[[ 1,  0],
         [ 1, -1],
         [ 0,  0],
         [ 0, -1]],

        [[ 1,  1],
         [ 1,  0],
         [ 0,  1],
         [ 0,  0]],

        [[ 2,  0],
         [ 2, -1],
         [ 1,  0],
         [ 1, -1]],

        [[ 2,  1],
         [ 2,  0],
         [ 1,  1],
         [ 1,  0]]])
# 继续第2列 (列) 加M-1
tensor([[[1, 1],
         [1, 0],
         [0, 1],
         [0, 0]],

        [[1, 2],
         [1, 1],
         [0, 2],
         [0, 1]],

        [[2, 1],
         [2, 0],
         [1, 1],
         [1, 0]],

        [[2, 2],
         [2, 1],
         [1, 2],
         [1, 1]]])
#第一列 (行) 乘 2M-1(3)
tensor([[[3, 1],
         [3, 0],
         [0, 1],
         [0, 0]],

        [[3, 2],
         [3, 1],
         [0, 2],
         [0, 1]],

        [[6, 1],
         [6, 0],
         [3, 1],
         [3, 0]],

        [[6, 2],
         [6, 1],
         [3, 2],
         [3, 1]]])
#行列元素相加
tensor([[4, 3, 1, 0],
        [5, 4, 2, 1],
        [7, 6, 4, 3],
        [8, 7, 5, 4]])

这里就得到相对位置索引,这里对应的值需要到relative positional bias Table 中获取,一开始程序中就定一个了一个可学习的table,长度为[2M-1]*[2M-1], 这里M=2,也就是长度为9,正对应上边索引0-8

# define a parameter table of relative position bias
        #构造可学习的相对位置偏置table,长度为 (2H-1)*(2W-1)*(num_head)  
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

这里假设有两个attention头

from torch import nn
from timm.models.layers import DropPath, to_2tuple, trunc_normal_

relative_position_bias_table = nn.Parameter(
   torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), 2))  # 2*Wh-1 * 2*Ww-1, nH 假设有两个attn头
print(relative_position_bias_table.shape,"\n",relative_position_bias_table)
trunc_normal_(relative_position_bias_table, std=.02) #初始化bias_table

输出

torch.Size([9, 2])  #两个attn头,每个头(2M-1)*(2M-1)个数
 Parameter containing:
tensor([[0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.]], requires_grad=True)
Parameter containing:  #初始化后的数据
tensor([[-0.0340,  0.0181],
        [-0.0033, -0.0055],
        [ 0.0045,  0.0193],
        [ 0.0412, -0.0031],
        [ 0.0004, -0.0032],
        [ 0.0201, -0.0161],
        [ 0.0067,  0.0079],
        [ 0.0241, -0.0279],
        [-0.0125, -0.0291]], requires_grad=True)
relative_position_bias = relative_position_bias_table[relative_position_index.view(-1)]

print("index :\n",relative_position_index.view(-1).shape,"\n",relative_position_index.view(-1))
print("bias table 根据索引取值后的数据:\n",relative_position_bias.shape,"\n",relative_position_bias)


relative_position_bias=relative_position_bias.view(window_size[0] * window_size[1], window_size[0] * window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
print("维度变换:\n",relative_position_bias.shape,"\n",relative_position_bias) 
     
#转换为与attention shape一致
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() 
index :
 torch.Size([16]) 
 tensor([4, 3, 1, 0, 5, 4, 2, 1, 7, 6, 4, 3, 8, 7, 5, 4])  #索引展开成一维
bias table 根据索引取值后的数据:
 torch.Size([16, 2]) 
 tensor([[ 0.0004, -0.0032],
        [ 0.0412, -0.0031],
        [-0.0033, -0.0055],
        [-0.0340,  0.0181],
        [ 0.0201, -0.0161],
        [ 0.0004, -0.0032],
        [ 0.0045,  0.0193],
        [-0.0033, -0.0055],
        [ 0.0241, -0.0279],
        [ 0.0067,  0.0079],
        [ 0.0004, -0.0032],
        [ 0.0412, -0.0031],
        [-0.0125, -0.0291],
        [ 0.0241, -0.0279],
        [ 0.0201, -0.0161],
        [ 0.0004, -0.0032]], grad_fn=<IndexBackward>)
维度变换:
 torch.Size([4, 4, 2]) 
 tensor([[[ 0.0004, -0.0032],
         [ 0.0412, -0.0031],
         [-0.0033, -0.0055],
         [-0.0340,  0.0181]],

        [[ 0.0201, -0.0161],
         [ 0.0004, -0.0032],
         [ 0.0045,  0.0193],
         [-0.0033, -0.0055]],

        [[ 0.0241, -0.0279],
         [ 0.0067,  0.0079],
         [ 0.0004, -0.0032],
         [ 0.0412, -0.0031]],

        [[-0.0125, -0.0291],
         [ 0.0241, -0.0279],
         [ 0.0201, -0.0161],
         [ 0.0004, -0.0032]]], grad_fn=<ViewBackward>)

在这里插入图片描述

以上代码就是有关相对位置偏置的全部内容了。

  • 2
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值