# -*- coding:utf-8 -*-
import torch
def relative_position_index(window_size, verbose=False):
    """Build the relative-position index table for a 2-D window of tokens.

    The window holds ``Wh x Ww`` positions, flattened row-major into
    ``N = Wh * Ww`` tokens. For every ordered pair of tokens (i, j) the
    offset ``(dy, dx) = pos_i - pos_j`` is computed and encoded as one
    integer in ``[0, (2*Wh - 1) * (2*Ww - 1) - 1]``. Equal offsets get equal
    indices, and distinct offsets get distinct indices.

    Args:
        window_size: ``(Wh, Ww)`` as a 2-element sequence, or a single int
            for a square window.
        verbose: if True, print every intermediate tensor (the step-by-step
            demo run under ``__main__``).

    Returns:
        An ``(N, N)`` integer tensor; entry ``[i, j]`` encodes ``pos_i - pos_j``.
    """
    if isinstance(window_size, int):
        window_size = (window_size, window_size)
    wh, ww = window_size

    def show(label, value):
        # Same formatting as a plain print(label, value) in the original demo.
        if verbose:
            print(label, value)

    coords_h = torch.arange(wh)
    show("coords_h:\n", coords_h)
    coords_w = torch.arange(ww)
    show("coords_w:\n", coords_w)
    try:
        grids = torch.meshgrid(coords_h, coords_w, indexing="ij")
    except TypeError:
        # torch < 1.10 has no `indexing` keyword; it always behaves like "ij".
        grids = torch.meshgrid(coords_h, coords_w)
    coords = torch.stack(grids)  # 2, Wh, Ww  (row coords, column coords)
    show("coords:\n", coords)
    coords_flatten = torch.flatten(coords, 1)  # 2, N
    show("coords_flatten:\n", coords_flatten)
    # Broadcast (2, N, 1) - (2, 1, N): entry [:, i, j] is pos_i - pos_j.
    relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, N, N
    show("relative_coords_1:\n", relative_coords)
    relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # N, N, 2
    show("relative_coords_2:\n", relative_coords)
    # dy lies in [-(Wh-1), Wh-1]; shift it to start from 0: [0, 2*Wh-2].
    relative_coords[:, :, 0] += wh - 1
    show("relative_coords[:, :, 0]:\n", relative_coords[:, :, 0])
    # dx lies in [-(Ww-1), Ww-1]; shift it to [0, 2*Ww-2].
    relative_coords[:, :, 1] += ww - 1
    show("relative_coords[:, :, 1]:\n", relative_coords[:, :, 1])
    # Row-major encoding: dx takes 2*Ww-1 distinct values, so using that as
    # the stride for dy makes dy*(2*Ww-1) + dx unique per (dy, dx) pair.
    relative_coords[:, :, 0] *= 2 * ww - 1
    show("relative_coords[:, :, 0]:\n", relative_coords[:, :, 0])
    index = relative_coords.sum(-1)  # N, N
    show("relative_pos_index:\n", index)
    return index


if __name__ == "__main__":
    relative_position_index([2, 2], verbose=True)
# Expected output for window_size = [2, 2]
# (the coords_h / coords_w prints are omitted):
#
# coords:
# tensor([[[0, 0],
#          [1, 1]],
#
#         [[0, 1],
#          [0, 1]]])
# coords_flatten:
# tensor([[0, 0, 1, 1],
#         [0, 1, 0, 1]])
# relative_coords_1:
# tensor([[[ 0,  0, -1, -1],
#          [ 0,  0, -1, -1],
#          [ 1,  1,  0,  0],
#          [ 1,  1,  0,  0]],
#
#         [[ 0, -1,  0, -1],
#          [ 1,  0,  1,  0],
#          [ 0, -1,  0, -1],
#          [ 1,  0,  1,  0]]])
# relative_coords_2:
# tensor([[[ 0,  0],
#          [ 0, -1],
#          [-1,  0],
#          [-1, -1]],
#
#         [[ 0,  1],
#          [ 0,  0],
#          [-1,  1],
#          [-1,  0]],
#
#         [[ 1,  0],
#          [ 1, -1],
#          [ 0,  0],
#          [ 0, -1]],
#
#         [[ 1,  1],
#          [ 1,  0],
#          [ 0,  1],
#          [ 0,  0]]])
# relative_coords[:, :, 0]:          (after adding Wh - 1)
# tensor([[1, 1, 0, 0],
#         [1, 1, 0, 0],
#         [2, 2, 1, 1],
#         [2, 2, 1, 1]])
# relative_coords[:, :, 1]:          (after adding Ww - 1)
# tensor([[1, 0, 1, 0],
#         [2, 1, 2, 1],
#         [1, 0, 1, 0],
#         [2, 1, 2, 1]])
# relative_coords[:, :, 0]:          (after multiplying by 2 * Ww - 1)
# tensor([[3, 3, 0, 0],
#         [3, 3, 0, 0],
#         [6, 6, 3, 3],
#         [6, 6, 3, 3]])
# relative_pos_index:
# tensor([[4, 3, 1, 0],
#         [5, 4, 2, 1],
#         [7, 6, 4, 3],
#         [8, 7, 5, 4]])