forward和backward是相反的过程
在forward的时候,输入是bottom_data和bottom_rois,输出是top_data和argmax_data;也就是把任意大小的特征图区域(ROI)池化成固定尺寸特征图的过程
在backward的时候,输入是top_diff、bottom_rois和argmax_data,输出是bottom_diff;也就是把固定尺寸特征图上的梯度传回任意尺寸特征图的过程,只有前向时被argmax_data记录过下标的位置才会有梯度回传
// Backward pass of ROI max pooling: computes the gradient with respect to
// the input feature map.
//
// This is a source template: Dtype, Dtype_ind, nthreads, num_rois, channels,
// height, width, pooled_height, pooled_width and spatial_scale are
// substituted before compilation. CUDA_KERNEL_LOOP is a macro defined
// outside this block (presumably a grid-stride loop over [0, nthreads) --
// confirm against its definition).
//
// Parameters:
//   top_diff    - gradient of the pooled output, indexed as
//                 ((roi_n * channels + c) * pooled_height + ph)
//                     * pooled_width + pw
//   argmax_data - same indexing as top_diff; each entry is compared against
//                 h * width + w, i.e. the position inside one channel plane
//                 of the input that the forward pass recorded as the max
//   bottom_rois - 5 values per ROI: [batch index, w start, h start, w end,
//                 h end]; the four coordinates are multiplied by
//                 spatial_scale and rounded before use
//   bottom_diff - output gradient; element `index` decodes to
//                 (num, c, h, w) with w varying fastest
//
// NOTE: the pooled width is taken from the pooled_width template key.
// Earlier revisions used pooled_height for every width-related quantity,
// which is only correct when the pooled output is square. The host-side
// substitution must supply pooled_width (equal to pooled_height for square
// output, which reproduces the previous behaviour exactly).
extern "C"
__global__ void roi_pooling2d_backward_grad_input_kernel(
    const ${Dtype}* const top_diff, const ${Dtype_ind}* const argmax_data,
    const ${Dtype}* bottom_rois, ${Dtype}* const bottom_diff) {
  CUDA_KERNEL_LOOP(index, ${nthreads}) {
    // Unlike the forward kernel, `index` here enumerates elements of the
    // input feature map (bottom), not of the pooled output.
    int w = index % ${width};
    int h = (index / ${width}) % ${height};
    int c = (index / (${width} * ${height})) % ${channels};
    int num = index / (${width} * ${height} * ${channels});

    // Value the forward pass stored in argmax_data when (h, w) was the max.
    const int plane_index = h * ${width} + w;

    float gradient = 0;
    // Accumulate gradient over all ROIs that pooled this element.
    for (int roi_n = 0; roi_n < ${num_rois}; ++roi_n) {
      // Skip ROIs that belong to a different batch item.
      if (num != static_cast<int>(bottom_rois[roi_n * 5])) {
        continue;
      }

      int roi_start_w = round(bottom_rois[roi_n * 5 + 1]
                              * ${spatial_scale});
      int roi_start_h = round(bottom_rois[roi_n * 5 + 2]
                              * ${spatial_scale});
      int roi_end_w = round(bottom_rois[roi_n * 5 + 3]
                            * ${spatial_scale});
      int roi_end_h = round(bottom_rois[roi_n * 5 + 4]
                            * ${spatial_scale});

      // Skip if the ROI does not contain (h, w).
      const bool in_roi = (w >= roi_start_w && w <= roi_end_w &&
                           h >= roi_start_h && h <= roi_end_h);
      if (!in_roi) {
        continue;
      }

      // Start of this (roi, channel) plane inside top_diff / argmax_data.
      int offset = (roi_n * ${channels} + c) * ${pooled_height}
                   * ${pooled_width};

      // Compute the feasible set of pooled units that could have pooled
      // this bottom unit.
      // Force malformed ROIs to be 1x1.
      int roi_width = max(roi_end_w - roi_start_w + 1, 1);
      int roi_height = max(roi_end_h - roi_start_h + 1, 1);

      float bin_size_h = static_cast<float>(roi_height)
                         / static_cast<float>(${pooled_height});
      float bin_size_w = static_cast<float>(roi_width)
                         / static_cast<float>(${pooled_width});

      int phstart = floor(static_cast<float>(h - roi_start_h)
                          / bin_size_h);
      int phend = ceil(static_cast<float>(h - roi_start_h + 1)
                       / bin_size_h);
      int pwstart = floor(static_cast<float>(w - roi_start_w)
                          / bin_size_w);
      int pwend = ceil(static_cast<float>(w - roi_start_w + 1)
                       / bin_size_w);

      // Clamp the candidate bin range to the pooled output extent.
      phstart = min(max(phstart, 0), ${pooled_height});
      phend = min(max(phend, 0), ${pooled_height});
      pwstart = min(max(pwstart, 0), ${pooled_width});
      pwend = min(max(pwend, 0), ${pooled_width});

      for (int ph = phstart; ph < phend; ++ph) {
        for (int pw = pwstart; pw < pwend; ++pw) {
          int index_ = ph * ${pooled_width} + pw + offset;
          // Gradient flows back only through the element that was
          // recorded as the max for this pooled bin.
          if (argmax_data[index_] == plane_index) {
            gradient += top_diff[index_];
          }
        }
      }
    }
    bottom_diff[index] = gradient;
  }
}