核心网络 ResNet
pa_avg_pool
pa_max_pool
全局平均池化
import torch.nn as nn
class GlobalPool(object):
    """Reduce a feature map to a single pooled vector per sample.

    The pooling operator is fixed at construction time from
    ``cfg.max_or_avg``: the exact string ``'avg'`` selects
    ``nn.AdaptiveAvgPool2d(1)``; every other value (not only ``'max'``)
    selects ``nn.AdaptiveMaxPool2d(1)``.
    """

    def __init__(self, cfg):
        """Pick the pooling operator.

        Args:
            cfg: configuration object; only its ``max_or_avg`` attribute is read.
        """
        if cfg.max_or_avg == 'avg':
            pool = nn.AdaptiveAvgPool2d(1)
        else:
            # NOTE(review): any value other than 'avg' falls through to max
            # pooling, so a misspelled setting is not reported.
            pool = nn.AdaptiveMaxPool2d(1)
        self.pool = pool

    def __call__(self, in_dict):
        """Pool ``in_dict['feat']`` and flatten everything after dim 0.

        Args:
            in_dict: dict holding the feature map under key ``'feat'``.
                Presumably a 4-D [N, C, H, W] tensor (the convention used
                elsewhere in this file) -- TODO confirm with callers.

        Returns:
            dict with a single key ``'feat_list'`` mapping to a one-element
            list containing the pooled tensor, reshaped to
            ``(size(0), -1)``. For [N, C, H, W] input this is [N, C].
        """
        pooled = self.pool(in_dict['feat'])
        batch_size = pooled.size(0)
        return {'feat_list': [pooled.view(batch_size, -1)]}
import torch.nn.functional as F
def pa_avg_pool(in_dict):
"""Mask weighted avg pooling.
Args:
feat: pytorch tensor, with shape [N, C, H, W]
mask: pytorch tensor, with shape [N, pC, pH, pW]
Returns:
feat_list: a list (length = pC) of pytorch tensors with shape [N, C]
visible: pytorch tensor with shape [N, pC]
"""
feat = in_di