前言
本文介绍 pytorch中两种特征裁剪函数(roi_align&grid_sample),在mask_rcnn提出利用roi_align进行目标的裁剪并进行重采样,另一种特征裁剪是利用grid_sample进行特征裁剪
一、grid_sample
pytorch中的介绍是:给定输入(input)和流场网格(flow_grid),使用网格中的输入值和像素位置来计算输出。 当前,仅支持空间(4-D)和体积(5-D)输入
其中:流场网格(flow_grid)是根据box坐标进行线性插值得到的
#输入
box_coord:[x1,y1,x1,y2,cls],
width:int,
height:int,
#输出:grid
def gen_grid(self, box_coord, width, height):
wmin, hmin, wmax, hmax = box_coord[:4]
# 间隔取点,进行线性插值,
grid_x = torch.linspace(wmin, wmax, width).view(1,1,width,1)
grid_y = torch.linspace(hmin, hmax, height).view(1,height,1,1)
grid_x = grid_x.expand(1,height,width,1)
grid_y = grid_y.expand(1,height,width,1)
grid = torch.cat((grid_x,grid_y), dim=-1)
return grid
对输入特征(feature_map)进行裁剪
roi = F.grid_sample(feature_map, grid)
注意:在进行计算flow_grid时需要先将坐标进行归一化
orm_H, norm_W = (img_H-1)/2, (img_W-1)/2
bboxes[:,[0,2]] = bboxes[:,[0,2]]*norm_W + norm_W
bboxes[:,[1,3]] = bboxes[:,[1,3]]*norm_H + norm_H
二、roi_align
原理部分可移步图解 RoIAlign 以及在 PyTorch 中的使用(含代码示例)
#输入
features:输入特征--Tensor
rois:裁剪范围--Tensor[K, 5] or List[Tensor[L, 4]]
resample:(28,28)
spatial_scale:1
Align_rois = roi_align(features,rois,resample, spatial_scale=1.0)
其中rois为[K, 5]:[index,x1,y1,x1,y2],包含裁剪坐标以及裁剪相应的特征索引。