Pytorch实现CapsuleNet

这里不讨论capsule的设计原理、优势以及特点等信息,只关注Capsule Net是如何实现的。

在这里插入图片描述

总体流程

  1. 在(28,28)的图片上进行卷积操作,得到feature map(20,20,256)
  2. concat 8个卷积得到的feature map,将其作为capsule(2048,8)
  3. 使用转移矩阵W将每个8维的capsule转换为10个16维的高级capsule(2048,10,16),再加权求和这2048个高级capsule得到DigitCaps(10,16),使用动态路由算法调整W.
  4. 将长度作为概率进行预测,并将概率最高的向量通过全连接层进行重构,分别计算分类损失和重构损失。

Conv1原图上提取低级特征

输入的mnist图片维度是(28,28),首先经过一个尺寸为(9,9)的卷积核,输出的feature map为(20,20,256)。

class ConvLayer(nn.Module):
    def __init__(self, in_channels=1, out_channels=256, kernel_size=9):
        super(ConvLayer, self).__init__()

        self.conv = nn.Conv2d(in_channels=in_channels,
                              out_channels=out_channels,
                              kernel_size=kernel_size,
                              stride=1
                              )

    def forward(self, x):
        return F.relu(self.conv(x))

PrimaryCaps生成Capsules

capsules的产生同样是使用卷积得来的。论文中通过PrimaryCaps输出的capsule一共有32个,每个capsule尺寸是(6,6,8),于是卷积核被设计成(9,9),步长为2,输出通道数为32,每个卷积核输出的feature map为(6,6,32)。一共有8个这样的卷积核,也就意味着每个卷积核都产生capsule的一个维度。然后将这8个feature map拼接起来,得到的feature map为(8,32,6,6)。capsule的个数为2048,每个capsule是一个8维向量。最后,再将每个capsule进行squash操作,将每个向量的长度控制到0-1之间。

这个让我想起capsule与CNN的概念上的区别,CNN每个神经元是一个标量,而capsule的神经元是一个矢量,从上面这个操作我们能看到这个矢量是怎么得来的。如果说每个矢量的维度为8,那就设计8个卷积,将得到的8个feature map拼接起来,就将8个标量转换成一个8维的矢量,这是什么神仙操作?

class PrimaryCaps(nn.Module):
    def __init__(self, num_capsules=8, in_channels=256, out_channels=32, kernel_size=9, num_routes=32 * 6 * 6):
        super(PrimaryCaps, self).__init__()
        self.num_routes = num_routes
        self.capsules = nn.ModuleList([
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=2, padding=0)
            for _ in range(num_capsules)])

    def forward(self,</
  • 1
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 6
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值