class SinePositionalEncoding(BaseModule):
    """Position encoding with sine and cosine functions.

    See `End-to-End Object Detection with Transformers
    <https://arxiv.org/pdf/2005.12872>`_ for details.

    Args:
        num_feats (int): The feature dimension for each position
            along x-axis or y-axis. Note the final returned dimension
            for each position is 2 times of this value.
        temperature (int, optional): The temperature used for scaling
            the position embedding. Defaults to 10000.
        normalize (bool, optional): Whether to normalize the position
            embedding. Defaults to False.
        scale (float, optional): A scale factor that scales the position
            embedding. The scale will be used only when `normalize` is True.
            Defaults to 2*pi.
        eps (float, optional): A value added to the denominator for
            numerical stability. Defaults to 1e-6.
        offset (float): offset add to embed when do the normalization.
            Defaults to 0.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None
    """

    def __init__(self,
                 num_feats,
                 temperature=10000,
                 normalize=False,
                 scale=2 * math.pi,
                 eps=1e-6,
                 offset=0.,
                 init_cfg=None):
        super(SinePositionalEncoding, self).__init__(init_cfg)
        if normalize:
            # `scale` is only meaningful when normalization is requested.
            assert isinstance(scale, (float, int)), 'when normalize is set,' \
                'scale should be provided and in float or int type, ' \
                f'found {type(scale)}'
        self.num_feats = num_feats
        self.temperature = temperature
        self.normalize = normalize
        self.scale = scale
        self.eps = eps
        self.offset = offset

    def forward(self, mask):
        """Forward function for `SinePositionalEncoding`.

        Args:
            mask (Tensor): ByteTensor mask. Non-zero values representing
                ignored positions, while zero values means valid positions
                for this image. Shape [bs, h, w].

        Returns:
            pos (Tensor): Returned position embedding with shape
                [bs, num_feats*2, h, w].
        """
        # For convenience of exporting to ONNX, it's required to convert
        # `mask` from bool to int before doing arithmetic on it.
        mask = mask.to(torch.int)
        # Flip the mask so valid (non-padded) pixels are 1 and padded are 0.
        inv_mask = 1 - mask
        # Running sums along height / width give each valid pixel its
        # 1-based row / column coordinate.
        y_embed = inv_mask.cumsum(1, dtype=torch.float32)
        x_embed = inv_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            # Divide by the last (maximal) coordinate along each axis so the
            # embeddings fall in (0, scale]; eps guards against division by 0.
            y_max = y_embed[:, -1:, :] + self.eps
            x_max = x_embed[:, :, -1:] + self.eps
            y_embed = (y_embed + self.offset) / y_max * self.scale
            x_embed = (x_embed + self.offset) / x_max * self.scale
        # Per-channel frequencies: temperature^(2*floor(i/2)/num_feats),
        # so consecutive (sin, cos) pairs share one frequency.
        freq_idx = torch.arange(
            self.num_feats, dtype=torch.float32, device=mask.device)
        freqs = self.temperature**(2 * (freq_idx // 2) / self.num_feats)
        # [bs, h, w] -> [bs, h, w, num_feats]
        pos_x = x_embed[:, :, :, None] / freqs
        pos_y = y_embed[:, :, :, None] / freqs
        # use `view` instead of `flatten` for dynamically exporting to ONNX
        bs, h, w = mask.size()
        # Apply sin to even channels and cos to odd channels, then interleave:
        # [bs, h, w, num_feats/2, 2] -> [bs, h, w, num_feats]
        pos_x = torch.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()),
            dim=4).view(bs, h, w, -1)
        pos_y = torch.stack(
            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()),
            dim=4).view(bs, h, w, -1)
        # Concatenate y then x and move channels first:
        # [bs, h, w, 2*num_feats] -> [bs, 2*num_feats, h, w]
        return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)