基于pytorch的胶囊网络minst图像分类实现

关于《Dynamic Routing Between Capsules》这篇论文的代码复现网上有很多,基本都是做图像重构的。我修改了其中一部分代码,实现了minst图像分类。
参考:基于pytorch的CapsNet代码详解.

胶囊网络结构

在这里插入图片描述
胶囊网络基本结构如下:

  • 普通卷积层conv1
  • 预胶囊层PrimaryCaps:为胶囊层做准备,运算为卷积运算。
  • 胶囊层DigitCaps:代替全连接层,输出为10个胶囊。

图像重构和图像分类不同的是,胶囊层后面还接了一个decoder层,将输出的胶囊转化为图像,所以前面的代码都是一样的。下面我们直接来看分类的实现。

package

import torch
from torch import nn
from torch.optim import Adam
import torch.nn.functional as F
from torchvision import transforms, datasets
import time

dataset

def load_mnist(path='./data', download=False, batch_size=100, shift_pixels=2):
    """
    Construct dataloaders for training and test data. Data augmentation is also done here.
    :param path: file path of the dataset
    :param download: whether to download the original data
    :param batch_size: batch size
    :param shift_pixels: maximum number of pixels to shift in each direction
    :return: train_loader, test_loader
    """
    kwargs = {
   'num_workers': 1, 'pin_memory': True}

    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(path, train=True, download=download,
                       transform=transforms.ToTensor()),
        batch_size=batch_size, shuffle=True, **kwargs)

    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(path, train=False, download=download,
                       transform=transforms.ToTensor()),
        batch_size=batch_size, shuffle=True, **kwargs)

    return train_loader, test_loader

model

1)Sqush激活函数
在这里插入图片描述

def squash(inputs, axis=-1):
    """
    The non-linear activation used in Capsule. It drives the length of a large vector to near 1 and small vector to 0
    :param inputs: vectors to be squashed
    :param axis: the axis to squash
    :return: a Tensor with same size as inputs
    """
    norm = torch.norm(inputs, p=2, dim=axis, keepdim=True)
    scale = norm**2 / (1 + norm**2) / (norm + 1e-8)
    return scale * inputs

此函数用来将verctor(胶囊就是vector)的长度(范数)压缩到0和1之间,axis=-1表示压缩倒数第一维。

2) PrimaryCapsule

class PrimaryCapsule(nn.Module):
    """
    Apply Conv2D with `out_channels` and then reshape to get capsules
    :param in_channels: input channels
    :param out_channels: output channels
    :param dim_caps: dimension of capsule
    :param kernel_size: kernel 
  • 6
    点赞
  • 97
    收藏
    觉得还不错? 一键收藏
  • 9
    评论
评论 9
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值