RoIAlign 的原理这里就不介绍了,可参考这个链接,博主在其中已经介绍得非常清楚。
今天刚把对应的 caffe2 源码看了一遍,并添加了详细的注释,希望能帮助大家理解,有问题欢迎指正,谢谢!
#include "roi_align_op.h"
#include "caffe2/utils/eigen_utils.h"
#include "caffe2/utils/math.h"
#ifdef CAFFE2_USE_MKL
#include "caffe2/mkl/operators/operator_fallback_mkl.h"
#endif // CAFFE2_USE_MKL
namespace caffe2 {
namespace {
// Precomputed bilinear-interpolation data for a single sampling point:
// the indices of its four neighbouring feature-map positions (pos1..pos4)
// and the weight applied to each of them (w1..w4).
// The out-of-bounds path in pre_calc_for_bilinear_interpolate stores an
// entry with every index and weight set to zero.
// NOTE(review): the corner order (which neighbour is pos1/w1, etc.) and how
// the indices are flattened are decided by the code that fills this struct
// (not fully visible here) — confirm there before relying on it.
template <typename T>
struct PreCalc {
  int pos1, pos2, pos3, pos4; // neighbour indices
  T w1, w2, w3, w4;           // matching interpolation weights
};
/*
height,//输入数据(featuremap)的第三个维度长度,即 h
width,//输入数据(featuremap)的第四个维度长度,即 w
pooled_height,//pooling 后输出的 h
pooled_width,//pooling 后输出的 w
iy_upper,//每个小网格(bin)内垂直方向实际遍历的采样点数(循环上界)
ix_upper,//每个小网格(bin)内水平方向实际遍历的采样点数(循环上界)
roi_start_h,//roi 在输入图像中的坐标 y1 映射到 RoIAlign 输入 featuremap 上的坐标,类型为 T(浮点)
roi_start_w,//roi 在输入图像中的坐标 x1 映射到 RoIAlign 输入 featuremap 上的坐标,类型为 T(浮点)
bin_size_h,//roi 被划分为 pooled_height 行小网格后,每个小网格在 featuremap 上的高度(浮点尺寸,不是数量)
bin_size_w,//roi 被划分为 pooled_width 列小网格后,每个小网格在 featuremap 上的宽度(浮点尺寸,不是数量)
roi_bin_grid_h,//每个小网格内垂直方向的采样点数,用于计算采样点之间的间距
roi_bin_grid_w,//每个小网格内水平方向的采样点数,用于计算采样点之间的间距
*/
//获得双线性插值采样点周围四个坐标的索引以及对应的权重
template <typename T>
void pre_calc_for_bilinear_interpolate(
const int height,
const int width,
const int pooled_height,
const int pooled_width,
const int iy_upper,
const int ix_upper,
T roi_start_h,
T roi_start_w,
T bin_size_h,
T bin_size_w,
int roi_bin_grid_h,
int roi_bin_grid_w,
std::vector<PreCalc<T>>& pre_calc) {
int pre_calc_index = 0;
for (int ph = 0; ph < pooled_height; ph++) {
for (int pw = 0; pw < pooled_width; pw++) {
for (int iy = 0; iy < iy_upper; iy++) {
//计算采样点垂直坐标,按每个小网格大小进行均匀采样roi_bin_grid_h个值
const T yy = roi_start_h + ph * bin_size_h +
static_cast<T>(iy + .5f) * bin_size_h /
static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
//计算采样点水平坐标,按每个小网格大小进行均匀采样roi_bin_grid_w个值
for (int ix = 0; ix < ix_upper; ix++) {
const T xx = roi_start_w + pw * bin_size_w +
static_cast<T>(ix + .5f) * bin_size_w /
static_cast<T>(roi_bin_grid_w);
T x = xx;
T y = yy;
// deal with: inverse elements are out of feature map boundary
//处理越界
if (y < -1.0 || y > height || x < -1.0 || x > width) {
// empty
PreCalc<T> pc;
pc.pos1 = 0;
pc.pos2 = 0;
pc.pos3 = 0;
pc.pos4 = 0;
pc.w1 = 0;
pc.w2 = 0;
pc.w3 = 0;
pc.w4 = 0;
pre_calc[pre_calc_index] = pc;
pre_calc_index += 1;
continue;
}
if (y <= 0) {
y = 0;
}
if (x <= 0) {
x = 0;
}
int y_low = (int)y;//采样点y向下取整,找其上方最近的整数坐标
int x_low = (i