胶囊网络CapNet从PrimaryCaps到DigitCaps是怎么实现的?

胶囊网络CapNet从PrimaryCaps到DigitCaps是怎么实现的?

在这里插入图片描述

先看一下代码,有注释,八成能懂.

这里假设batch_size=100, 下文的维度中的100都是指的batch_size.

首先PrimaryCaps中这里是通过8个普通卷积实现的.

从PrimaryCaps这里的特征图在上图中很清楚shape为[100, 32, 6, 6, 8],但是你可以把这个8看做是一个向量, 对于特征图中的一个点就可以用一个长度为8的向量表示,一共有多少这样的点呢? 3266个.
则PrimaryCaps出来的 形状就变为 [100, 1152, 8], 这里对应[bz, num_route_nodes, len_cap].

从PrimaryCaps到DigitCaps经历了什么?

从代码中可以看到假设DigitCaps的输入为x.则x.shape=[100,1152,8], 先把它扩成[1,100,1152,1,8].

权值初始化shape为, weight.shape = [num_capsules, num_route_nodes, in_channels, out_channels], 这里也就是 [10,1152,8,16], 同样把它扩成 [10, 1, 1152, 8, 16].

那么执行矩阵相乘操作@后, 输出out.shape=[10, 100, 1152, 1, 16], 这里再对1152个节点进行加权求和以及挤压(squash)操作后,输出变为了 [10,100,1,1,16], 然后再squeeze().transpose(0,1)变为[100, 10, 16].

class CapsuleLayer(nn.Module):
    def __init__(self, num_capsules, num_route_nodes, in_channels, out_channels, kernel_size=None, stride=None,
                 num_iterations=NUM_ROUTING_ITERATIONS):
        super(CapsuleLayer, self).__init__()

        self.num_route_nodes = num_route_nodes
        self.num_iterations = num_iterations

        self.num_capsules = num_capsules

        if num_route_nodes != -1:
            self.route_weights = nn.Parameter(torch.randn(num_capsules, num_route_nodes, in_channels, out_channels))
        else:
            self.capsules = nn.ModuleList(
                [nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=0) for _ in
                 range(num_capsules)])

    def squash(self, tensor, dim=-1):
        squared_norm = (tensor ** 2).sum(dim=dim, keepdim=True)
        scale = squared_norm / (1 + squared_norm)
        return scale * tensor / torch.sqrt(squared_norm)

    def forward(self, x):
        if self.num_route_nodes != -1:  # DigitCaps x.shape=[100,1152,8]; self.route_weights.shape=[10,1152,8,16]
            priors = x[None, :, :, None, :] @ self.route_weights[:, None, :, :, :]  # @表示矩阵相乘, *表示矩阵逐元素相乘
            # priors.shape=[10,100,1152,1,16] (num_caps, bz, num_nodes, 1, len_cap)
            # print("*priors.size(): ", *priors.size())  # 10,100,1152,1,16
            logits = Variable(torch.zeros(*priors.size())).cuda()
            for i in range(self.num_iterations):  # interpret blog: https://www.jianshu.com/p/83309cdb9326
                probs = softmax(logits, dim=2)  # [10,100,1152,1,16]
                outputs = self.squash((probs * priors).sum(dim=2, keepdim=True))  # [10,100,1,1,16]

                if i != self.num_iterations - 1:
                    delta_logits = (priors * outputs).sum(dim=-1, keepdim=True)
                    logits = logits + delta_logits
        else:  # PrimaryCaps
            outputs = [capsule(x).view(x.size(0), -1, 1) for capsule in self.capsules]  # 8capsules realized by 8convs
            # outputs is a list which has 8 Tensors, shape=[100,1152,1] (bz,c*w*h,1)

            outputs = torch.cat(outputs, dim=-1)  # [100,1152,8]
            outputs = self.squash(outputs)  # [100,1152,8]

        return outputs


class CapsuleNet(nn.Module):
    def __init__(self):
        super(CapsuleNet, self).__init__()

        self.conv1 = nn.Conv2d(in_channels=1, out_channels=256, kernel_size=9, stride=1)
        self.primary_capsules = CapsuleLayer(num_capsules=8, num_route_nodes=-1, in_channels=256, out_channels=32,
                                             kernel_size=9, stride=2)
        self.digit_capsules = CapsuleLayer(num_capsules=NUM_CLASSES, num_route_nodes=32 * 6 * 6, in_channels=8,
                                           out_channels=16)

        self.decoder = nn.Sequential(
            nn.Linear(16 * NUM_CLASSES, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 784),
            nn.Sigmoid()
        )

    def forward(self, x, y=None):
        x = F.relu(self.conv1(x), inplace=True)
        x = self.primary_capsules(x)  # out: [100,1152,8]
        x = self.digit_capsules(x).squeeze().transpose(0, 1)  # [10,100,1,1,16] -> [10,100,16] -> [100,10,16]

        classes = (x ** 2).sum(dim=-1) ** 0.5
        classes = F.softmax(classes, dim=-1)

        if y is None:
            # In all batches, get the most active capsule.
            _, max_length_indices = classes.max(dim=1)
            y = Variable(torch.eye(NUM_CLASSES)).cuda().index_select(dim=0, index=max_length_indices.data)

        reconstructions = self.decoder((x * y[:, :, None]).view(x.size(0), -1))

        return classes, reconstructions
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值