class GlobalAveragePooling(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x.mean([2, 3])
在GlobalAveragePooling2D有如下代码:
def call(self, inputs):
if self.data_format == 'channels_last':
return backend.mean(inputs, axis=[1, 2])
else:
return backend.mean(inputs, axis=[2, 3])
可以看到如果输入数据为channels_first,即通道在前面的话,GAP便是对第二第三维度进行求平均。