Swin Transformer 相对位置索引(relative position index)代码记录

# -*- coding:utf-8 -*-
import torch 

def compute_relative_position_index(window_size, verbose=False):
    """Compute the pairwise relative position index for one attention window.

    For a window of ``Wh x Ww`` tokens (flattened row-major), entry ``[i, j]``
    of the result encodes the 2-D offset between token ``i`` and token ``j``
    as a single integer in ``[0, (2*Wh-1)*(2*Ww-1) - 1]``.  This mirrors the
    index construction used for the relative position bias in Swin
    Transformer.

    Args:
        window_size: Two-element sequence ``(Wh, Ww)`` with the window height
            and width.
        verbose: When true, print every intermediate tensor with the same
            labels the original demo script used.

    Returns:
        An integer tensor of shape ``(Wh*Ww, Wh*Ww)``.
    """
    win_h, win_w = window_size

    def _show(label, value):
        # Trace helper: keeps the computation below free of `if verbose`.
        if verbose:
            print(label, value)

    coords_h = torch.arange(win_h)
    _show("coords_h:\n", coords_h)
    coords_w = torch.arange(win_w)
    _show("coords_w:\n", coords_w)

    # indexing="ij" pins the row/column (matrix) ordering explicitly; calling
    # meshgrid without it is deprecated and warns on recent torch versions.
    coords = torch.stack(
        torch.meshgrid([coords_h, coords_w], indexing="ij")
    )  # 2, Wh, Ww
    _show("coords:\n", coords)

    coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
    _show("coords_flatten:\n", coords_flatten)

    # Broadcasted subtraction: offset (i - j) for every token pair, per axis.
    relative_coords = (
        coords_flatten[:, :, None] - coords_flatten[:, None, :]
    )  # 2, Wh*Ww, Wh*Ww
    _show("relative_coords_1:\n", relative_coords)

    # contiguous() gives the permuted tensor its own dense layout before the
    # in-place updates below.
    relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
    _show("relative_coords_2:\n", relative_coords)

    # Offsets lie in [-(W-1), W-1]; shift each axis so it starts from 0.
    relative_coords[:, :, 0] += win_h - 1
    _show("relative_coords[:, :, 0]:\n", relative_coords[:, :, 0])
    relative_coords[:, :, 1] += win_w - 1
    _show("relative_coords[:, :, 1]:\n", relative_coords[:, :, 1])

    # Scale the row offset by the number of distinct column offsets
    # (2*Ww - 1) so that row*stride + col is unique for every (row, col) pair.
    relative_coords[:, :, 0] *= 2 * win_w - 1
    _show("relative_coords[:, :, 0]:\n", relative_coords[:, :, 0])

    return relative_coords.sum(-1)  # Wh*Ww, Wh*Ww


if __name__ == "__main__":
    window_size = [2, 2]
    relative_position_index = compute_relative_position_index(
        window_size, verbose=True
    )
    print("relative_pos_index:\n", relative_position_index)

 

coords_h:
 tensor([0, 1])
coords_w:
 tensor([0, 1])
coords:
 tensor([[[0, 0],
         [1, 1]],

        [[0, 1],
         [0, 1]]])
coords_flatten:
 tensor([[0, 0, 1, 1],
        [0, 1, 0, 1]])
relative_coords_1:
 tensor([[[ 0,  0, -1, -1],
         [ 0,  0, -1, -1],
         [ 1,  1,  0,  0],
         [ 1,  1,  0,  0]],

        [[ 0, -1,  0, -1],
         [ 1,  0,  1,  0],
         [ 0, -1,  0, -1],
         [ 1,  0,  1,  0]]])
relative_coords_2:
 tensor([[[ 0,  0],
         [ 0, -1],
         [-1,  0],
         [-1, -1]],

        [[ 0,  1],
         [ 0,  0],
         [-1,  1],
         [-1,  0]],

        [[ 1,  0],
         [ 1, -1],
         [ 0,  0],
         [ 0, -1]],

        [[ 1,  1],
         [ 1,  0],
         [ 0,  1],
         [ 0,  0]]])
relative_coords[:, :, 0]:
 tensor([[1, 1, 0, 0],
        [1, 1, 0, 0],
        [2, 2, 1, 1],
        [2, 2, 1, 1]])
relative_coords[:, :, 1]:
 tensor([[1, 0, 1, 0],
        [2, 1, 2, 1],
        [1, 0, 1, 0],
        [2, 1, 2, 1]])
relative_coords[:, :, 0]:
 tensor([[3, 3, 0, 0],
        [3, 3, 0, 0],
        [6, 6, 3, 3],
        [6, 6, 3, 3]])
relative_pos_index:
 tensor([[4, 3, 1, 0],
        [5, 4, 2, 1],
        [7, 6, 4, 3],
        [8, 7, 5, 4]])

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值