胶囊神经网络资源转载和Pytorch实现

本文介绍胶囊神经网络的基本原理,对比传统卷积神经网络(CNN)的不足之处,并深入探讨胶囊网络的结构,包括PrimaryCapsuleLayer与DigitCapsuleLayer的实现细节。此外,还提供使用Pytorch实现胶囊网络的代码示例。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

胶囊神经网络资源转载和Pytorch实现

github那个代码是pytorch0.4之前才能用,所以这里附上博主pytorch1.6写的关键部分实现代码
首先是primary capsule layer

class PrimaryCapsuleLayer(nn.Module):
   
   def __init__(self, in_channels=256, out_channels=32, num_caps=8, kernel_size=9, stride=2):
       super().__init__()
       self.capsules = nn.ModuleList([
           nn.Conv2d(in_channels, out_channels, kernel_size, stride) for _ in range(num_caps)
       ])
       
   def _squash(self, x, dim=-1):
       # x_norm shape [B, C*H*W, 1]
       x_norm = torch.norm(x, p=2, dim=dim, keepdim=True)  # compute norm in the dim -1, namely across all capsules
       scale = x_norm**2 / (1 + x_norm**2)
       v = scale * x / x_norm
       return v
       
   def forward(self, x):
       # each capsule is a conv layer, outputs shape => list[[B, C*H*W, 1]]
       outputs = [capsule[x].reshape([x.shape[0], -1, 1]) for capsule in self.capsules]
       # shape: [B, C*H*W, num_caps]
       outputs = torch.cat(outputs, dim=-1)
       return self._squash(outputs)

然后是digit capsule layer

class DigitCapsuleLayer(nn.Module):
    
    def _squash(self, x, dim=-1):
        # x_norm shape [B, C*H*W, 1]
        x_norm = torch.norm(x, p=2, dim=dim, keepdim=True)  # compute norm in the dim -1, namely across all capsules
        scale = x_norm**2 / (1 + x_norm**2)
        v = scale * x / x_norm
        return v
    
    def __init__(self, num_caps, num_route_nodes, in_channels, out_channels, num_iterations=3):
        super().__init__()
        self.route_weights = nn.Parameter(torch.randn(num_caps, num_route_nodes, in_channels, out_channels))
        
    def forward(self, x):
        # shape [num_caps, B, C*H*W->num_route_nodes, 1, out_c]
        prior = x[None, :, :, None, :] @ self.route_weights[:, None, :, :, :]
        # shape [num_caps, B, num_route_nodes, 1, out_c]
        logits = torch.zeros(*prior.shape)
        
        for i in range(self.num_iterations):
            # shape [num_caps, B, C*H*W->num_route_nodes, 1, out_c], detach the node!
            u_hat = prior.detach()
            # shape [num_caps, B, 1, 1, out_c]
            probs = torch.softmax(logits, dim=2)
            # shape [num_caps, B, 1, 1, out_c]
            outputs = self._squash((u_hat * probs).sum(dim=2, keepdim=True))
            # shape [num_caps, B, num_route_nodes, 1, 1]
            delta_logits = (u_hat * outputs).sum(dim=-1, keepdim=True)
            # shape [num_caps, B, num_route_nodes, 1, out_c]
            logits += delta_logits
        # after iteration, we get the correct logits
        probs = torch.softmax(logits, dim=2)
        # shape [num_caps, B, 1, 1, out_c]
        outputs = self._squash((prior * probs).sum(dim=2, keepdim=True))
        return outputs

最后是整个胶囊网络整合,这个部分就和github那个差不多,这里就没全部写上去了

class CapsuleNetwork(nn.Module):
    
    def __init__(self):
        NUM_CLASSES = 10
        super().__init__()
        # basic conv layer to extract fature maps from mnist image
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=256, kernel_size=9, stride=1)
        # primary capsule_layer
        self.primary_capsules = PrimaryCapsuleLayer(in_channels=256, out_channels=32,
                                                    num_caps=8, kernel_size=9, stride=2)
        # digit capsule
        self.digit_capsules = DigitCapsuleLayer(num_caps=NUM_CLASSES, num_route_nodes=32*6*6,
                                                in_channels=8, out_channels=16, num_iterations=3)
        
        # decoder to reconstruct the digit image
        self.decoder = nn.Sequential(
            nn.Linear(16 * NUM_CLASSES, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 784),  # reconstruction (28 * 28 = 784)
            nn.Sigmoid()
        )

    def forward(self, x, y=None):
        pass
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值