pytorch 对AdaptiveAveragePooling源码实现的理解(只分析了outputsize=[1,1]的情况)

1.前言

最近有把AdaptiveAveragePooling转成其他pooling方式的需求,要求的情况是输出为outputsize=[1,1],想知道AdaptiveAveragePooling是怎么自适应计算kernel size和stride的,于是去瞄了一眼源码
源码位置:c++源码

2.outputsize=[1,1]情况下AdaptiveAveragePooling源码的处理

  Tensor adaptive_avg_pool2d(
    at::Tensor const& input,
    IntArrayRef output_size){
    if (output_size[0] == 1 && output_size[1] == 1) {
//in this case, adaptive pooling is just computing mean over hw dimensions, which can be done more efficiently

//input的最后两个维度相乘(H*W)赋给mean_size
       int64_t mean_size = input.size(-1) * input.size(-2);
//view代表一个reshape的功能,其中-1表示一个不确定的数,就是不确定想要reshape成几行,但是很肯定要reshape成mean_size列   
//reshape后对-1维用函数mean平均    
       Tensor out = input.contiguous().view({-1, mean_size}).mean(-1);
       return input.ndimension() == 3 ? out.view({input.size(0), 1, 1}) : out.view({input.size(0), input.size(1), 1, 1});
    } else {
       return _adaptive_avg_pool2d(input, output_size);
    }
  }

也即outputsize=[1,1]情况下对H*W维度求平均即可,其他维度原样输出

3.遗留问题

之前看过一博主讨论AdaptivePooling与Max/AvgPooling相互转换
但是评论区好多人说不对,留作未来这方面有疑问时再看源码
关键代码段

  inline int start_index(int a, int b, int c) {
    return (int)std::floor((float)(a * c) / b);
  }

  inline int end_index(int a, int b, int c) {
    return (int)std::ceil((float)((a + 1) * c) / b);
  }

  template <typename scalar_t>
  static void adaptive_avg_pool2d_out_frame(
            scalar_t *input_p,
            scalar_t *output_p,
            int64_t sizeD,
            int64_t isizeH,
            int64_t isizeW,
            int64_t osizeH,
            int64_t osizeW,
            int64_t istrideD,
            int64_t istrideH,
            int64_t istrideW)
  {
    int64_t d;
  #pragma omp parallel for private(d)
    for (d = 0; d < sizeD; d++)
    {
      /* loop over output */
      int64_t oh, ow;
      for(oh = 0; oh < osizeH; oh++)
      {
        int istartH = start_index(oh, osizeH, isizeH);
        int iendH   = end_index(oh, osizeH, isizeH);
        int kH = iendH - istartH;

        for(ow = 0; ow < osizeW; ow++)
        {
          int istartW = start_index(ow, osizeW, isizeW);
          int iendW   = end_index(ow, osizeW, isizeW);
          int kW = iendW - istartW;

          /* local pointers */
          scalar_t *ip = input_p   + d*istrideD + istartH*istrideH + istartW*istrideW;
          scalar_t *op = output_p  + d*osizeH*osizeW + oh*osizeW + ow;

          /* compute local average: */
          scalar_t sum = 0;
          int ih, iw;
          for(ih = 0; ih < kH; ih++)
          {
            for(iw = 0; iw < kW; iw++)
            {
              scalar_t val = *(ip + ih*istrideH + iw*istrideW);
              sum += val;
            }
          }

          /* set output to local average */
          *op = sum / kW / kH;
        }
      }
    }
  }
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值