StochasticPooling 随机池化的PyTorch实现

网上找了好久随机池化的Pytorch代码,都没有太合适的。就自己写了个不太优美的实现。若是在论文中使用该代码,请私信告知我用途。

from torch import nn
import torch
from torch.nn import functional as F

class StochasticPooling(nn.Module):
    def __init__(self):
        super().__init__()
        self.avg = nn.AvgPool2d(2)

    def forward(self, x):
        s = 4 * F.interpolate(F.avg_pool2d(x, 2), scale_factor=2, mode='nearest')
        p = x / s
        if self.training:
            b, c, h, w = p.shape
            o = torch.zeros(b, c, h // 2, w // 2)
            for i in range(h // 2):
                for j in range(w // 2):
                    pij = p[:, :, 2 * i:2 * i + 2, 2 * j:2 * j + 2].reshape(b, c, -1)
                    idx = torch.distributions.Multinomial(1, pij).sample().reshape(b, c, 2, 2)
                    o[:, :, i, j] = x[:, :, 2 * i:2 * i + 2, 2 * j:2 * j + 2][idx == 1].reshape(b, c)
            return o
        else:
            return 4 * F.avg_pool2d(p * x, 2)

需要注意的是,输入特征图的尺寸大小应该为2的整数倍,否则会报错,是由于下面这两行代码引起的

        s = 4 * F.interpolate(F.avg_pool2d(x, 2), scale_factor=2, mode='nearest')
        p = x / s
随机池化是一种池化方法,它与传统的最大池化和平均池化不同,它不是简单地选取最大或平均值,而是通过随机采样来选择池化后的值。在PyTorch中,可以通过自定义一个继承自nn.Module的类来实现随机池化。下面是一个简单的随机池化PyTorch实现代码: ``` from torch import nn import torch from torch.nn import functional as F class StochasticPooling(nn.Module): def __init__(self): super().__init__() self.avg = nn.AvgPool2d(2) def forward(self, x): s = 4 * F.interpolate(F.avg_pool2d(x, 2), scale_factor=2, mode='nearest') p = x / s if self.training: b, c, h, w = p.shape o = torch.zeros(b, c, h // 2, w // 2) for i in range(h // 2): for j in range(w // 2): pij = p[:, :, 2 * i:2 * i + 2, 2 * j:2 * j + 2].reshape(b, c, -1) idx = torch.distributions.Multinomial(1, pij).sample().reshape(b, c, 2, 2) o[:, :, i, j] = x[:, :, 2 * i:2 * i + 2, 2 * j:2 * j + 2][idx == 1].reshape(b, c) return o else: return 4 * F.avg_pool2d(p * x, 2) ``` 这个实现中,我们首先定义了一个继承自nn.Module的类StochasticPooling,它包含一个AvgPool2d层和一个前向传播函数forward。在前向传播函数中,我们首先计算了一个s值,然后通过s值计算出一个p值。如果是在训练模式下,我们会遍历每个池化区域,计算出每个像素被选中的概率pij,然后通过Multinomial分布采样得到一个idx值,最后根据idx值选择出对应的像素值。如果是在测试模式下,我们则直接使用平均池化和乘法来计算池化后的值。
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值