网上找了好久随机池化的Pytorch代码,都没有太合适的。就自己写了个不太优美的实现。若是在论文中使用该代码,请私信告知我用途。
from torch import nn
import torch
from torch.nn import functional as F
class StochasticPooling(nn.Module):
def __init__(self):
super().__init__()
self.avg = nn.AvgPool2d(2)
def forward(self, x):
s = 4 * F.interpolate(F.avg_pool2d(x, 2), scale_factor=2, mode='nearest')
p = x / s
if self.training:
b, c, h, w = p.shape
o = torch.zeros(b, c, h // 2, w // 2)
for i in range(h // 2):
for j in range(w // 2):
pij = p[:, :, 2 * i:2 * i + 2, 2 * j:2 * j + 2].reshape(b, c, -1)
idx = torch.distributions.Multinomial(1, pij).sample().reshape(b, c, 2, 2)
o[:, :, i, j] = x[:, :, 2 * i:2 * i + 2, 2 * j:2 * j + 2][idx == 1].reshape(b, c)
return o
else:
return 4 * F.avg_pool2d(p * x, 2)
需要注意的是,输入特征图的尺寸大小应该为2的整数倍,否则会报错,是由于下面这两行代码引起的
s = 4 * F.interpolate(F.avg_pool2d(x, 2), scale_factor=2, mode='nearest')
p = x / s