[Deep Learning] A PyTorch Implementation of Capsule Networks (CapsNet)

References:

  • Aliyun practice article: https://developer.aliyun.com/article/581717
  • Dynamic routing mechanism explained (video): https://www.bilibili.com/video/BV1oW411H7G1/?spm_id_from=333.337.search-card.all.click&vd_source=3b5e1109bdab0d21b23a5c46c4ed667d
  • Hinton et al.'s paper: Dynamic Routing Between Capsules
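
For quick reference, the code below implements the "squash" nonlinearity and the dynamic-routing update described in the Hinton paper. In this implementation the coupling coefficients c are a softmax of the routing logits b taken over dim=2 (the 1152 primary-capsule nodes):

$$
v_j = \frac{\lVert s_j \rVert^2}{1 + \lVert s_j \rVert^2}\,\frac{s_j}{\lVert s_j \rVert},
\qquad
s_j = \sum_i c_{ij}\,\hat{u}_{j|i},
\qquad
b_{ij} \leftarrow b_{ij} + \hat{u}_{j|i} \cdot v_j
$$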

Code:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np


class CapsuleLayer(nn.Module):
    def __init__(self, num_capsules, num_route_nodes, in_channels, out_channels, routing_iters=3):
        super(CapsuleLayer, self).__init__()
        self.num_route_nodes = num_route_nodes
        self.num_capsules = num_capsules
        self.routing_iters = routing_iters
        
        self.W = nn.Parameter(torch.randn(1, num_capsules, num_route_nodes, in_channels, out_channels)) 
        # ([1, 10, 1152, 8, 16])
        
    def forward(self, x):
        # ([128, 32, 6, 6, 8])
        x = x.view(x.size(0), x.size(1) * x.size(2) * x.size(3), x.size(4)) # ([128, 1152, 8])
        x = x.unsqueeze(1) # ([128, 1, 1152, 8])
        x = x.repeat(1, self.num_capsules, 1, 1) # ([128, 10, 1152, 8])
        x = x.unsqueeze(3) # ([128, 10, 1152, 1, 8])
        u_hat = torch.matmul(x, self.W)  # ([128, 10, 1152, 1, 16])
        u_hat = u_hat.squeeze(3) # ([128, 10, 1152, 16])
        
        # Routing logits b, initialized to zero on the same device as the input
        b = torch.zeros(x.size(0), self.num_capsules, self.num_route_nodes, 1, device=x.device) # ([128, 10, 1152, 1])
        
        # Dynamic routing: softmax over the 1152 routing nodes (dim=2), then agreement update
        for i in range(self.routing_iters):
            c = F.softmax(b, dim=2)
            s = (c * u_hat).sum(dim=2, keepdim=True) # ([128, 10, 1, 16])
            v = self.squash(s) # ([128, 10, 1, 16])
            if i < self.routing_iters - 1:
                b = b + (u_hat * v).sum(dim=-1, keepdim=True) # agreement between predictions and output
        
        return v # ([128, 10, 1, 16])
    
    def squash(self, input_tensor):
        # Squash: short vectors shrink toward zero, long vectors approach unit length
        squared_norm = (input_tensor ** 2).sum(dim=-1, keepdim=True)
        scale = squared_norm / (1 + squared_norm)
        output_tensor = scale * input_tensor / torch.sqrt(squared_norm + 1e-8) # epsilon avoids division by zero
        return output_tensor
    
class PrimaryCaps(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, dim_capsule):
        super(PrimaryCaps, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels * dim_capsule, kernel_size=kernel_size, stride=stride)
        self.dim_capsule = dim_capsule
    
    def forward(self, x):
        x = self.conv(x)
        batch_size = x.size(0)
        out_channels = x.size(1) // self.dim_capsule
        height = x.size(2)
        width = x.size(3)
        # Reshape to [batch_size, out_channels, height, width, dim_capsule] using view
        x = x.view(batch_size, out_channels, height, width, self.dim_capsule)
        return x
    
class CapsNet(nn.Module):
    def __init__(self, input_shape, n_class, routings):
        super(CapsNet, self).__init__()
        self.input_shape = input_shape
        self.n_class = n_class
        self.routings = routings
        
        self.conv1 = nn.Conv2d(in_channels=input_shape[0], out_channels=256, kernel_size=9, stride=1)
        self.primarycaps = PrimaryCaps(dim_capsule=8, in_channels=256, out_channels=32, kernel_size=9, stride=2)
        self.digitcaps = CapsuleLayer(num_capsules=n_class, num_route_nodes=32*6*6, in_channels=8, out_channels=16, routing_iters=routings)
        self.decoder = nn.Sequential(
            nn.Linear(16 * n_class, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, np.prod(input_shape)),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        x = F.relu(self.conv1(x)) # ([128, 256, 20, 20])
        x = self.primarycaps(x) # ([128, 32, 6, 6, 8])
        x = self.digitcaps(x) # ([128, 10, 1, 16])
        
        # Length of output capsules
        lengths = x.norm(dim=-1).squeeze(2) # ([128, 10])

        # Reconstruction
        x = x.view(x.size(0), -1) # ([128, 160])
        reconstructions = self.decoder(x)
        reconstructions = reconstructions.view(-1, *self.input_shape)
        
        return lengths, reconstructions

# Create CapsNet model
input_shape = (1, 28, 28)  # Example for MNIST
n_class = 10  # Number of classes
routings = 3  # Number of routing iterations

model = CapsNet(input_shape, n_class, routings)
print(model)
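
# Quick shape sanity check (a minimal sketch added for illustration; not part of the
# original training script). It runs the untrained model on a small dummy batch on CPU
# to confirm the shape annotations in the comments above.
with torch.no_grad():
    dummy = torch.randn(2, *input_shape)    # ([2, 1, 28, 28])
    lengths, reconstructions = model(dummy)
    print(lengths.shape)                    # torch.Size([2, 10])
    print(reconstructions.shape)            # torch.Size([2, 1, 28, 28])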

# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

# Load MNIST dataset
batch_size = 128
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Initialize model, optimizer, and loss function
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CapsNet(input_shape=(1, 28, 28), n_class=10, routings=3).to(device)
optimizer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()

# Training function
def train(model, train_loader, optimizer, criterion, epoch):
    model.train()
    train_loss = 0
    correct = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        lengths, reconstructions = model(data)
        classification_loss = criterion(lengths, target)
        reconstructions_loss = F.mse_loss(reconstructions, data)
        loss = classification_loss + reconstructions_loss
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        pred = torch.argmax(lengths, dim=1)
        correct += pred.eq(target).sum().item()
        
        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                  f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss(c/r): {loss.item():.6f} ({classification_loss.item():.6f} / {reconstructions_loss.item():.6f})')
    
    train_loss /= len(train_loader.dataset)
    accuracy = 100. * correct / len(train_loader.dataset)
    print(f'Train Epoch: {epoch}\tAverage loss: {train_loss:.4f}\tAccuracy: {accuracy:.2f}%')

# Testing function
def test(model, test_loader, criterion):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            lengths, reconstructions = model(data)
            loss = criterion(lengths, target) + F.mse_loss(reconstructions, data)
            test_loss += loss.item()
            pred = lengths.argmax(dim=-1)
            correct += pred.eq(target).sum().item()
    
    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} '
          f'({accuracy:.2f}%)\n')

# Train the model
epochs = 10
for epoch in range(1, epochs + 1):
    train(model, train_loader, optimizer, criterion, epoch)
    test(model, test_loader, criterion)

Below is the training log with the reconstruction loss removed from the total loss (a sketch of that change follows the log):

CapsNet(
  (conv1): Conv2d(1, 256, kernel_size=(9, 9), stride=(1, 1))
  (primarycaps): PrimaryCaps(
    (conv): Conv2d(256, 256, kernel_size=(9, 9), stride=(2, 2))
  )
  (digitcaps): CapsuleLayer()
  (decoder): Sequential(
    (0): Linear(in_features=160, out_features=512, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=512, out_features=1024, bias=True)
    (3): ReLU(inplace=True)
    (4): Linear(in_features=1024, out_features=784, bias=True)
    (5): Sigmoid()
  )
)
Train Epoch: 1 [0/60000 (0%)]	Loss: 2.301852
Train Epoch: 1 [12800/60000 (21%)]	Loss: 2.302585
Train Epoch: 1 [25600/60000 (43%)]	Loss: 2.302585
Train Epoch: 1 [38400/60000 (64%)]	Loss: 2.302585
Train Epoch: 1 [51200/60000 (85%)]	Loss: 2.302585
Train Epoch: 1	Average loss: 0.0180	Accuracy: 15.21%

Test set: Average loss: 0.0182, Accuracy: 2033/10000 (20.33%)

Train Epoch: 2 [0/60000 (0%)]	Loss: 2.302584
Train Epoch: 2 [12800/60000 (21%)]	Loss: 2.302584
Train Epoch: 2 [25600/60000 (43%)]	Loss: 2.302585
Train Epoch: 2 [38400/60000 (64%)]	Loss: 2.302584
Train Epoch: 2 [51200/60000 (85%)]	Loss: 2.302584
Train Epoch: 2	Average loss: 0.0180	Accuracy: 16.46%

Test set: Average loss: 0.0182, Accuracy: 1753/10000 (17.53%)

Train Epoch: 3 [0/60000 (0%)]	Loss: 2.302584
Train Epoch: 3 [12800/60000 (21%)]	Loss: 2.302584
Train Epoch: 3 [25600/60000 (43%)]	Loss: 2.302583
Train Epoch: 3 [38400/60000 (64%)]	Loss: 2.302582
Train Epoch: 3 [51200/60000 (85%)]	Loss: 2.302579
Train Epoch: 3	Average loss: 0.0180	Accuracy: 18.76%

Test set: Average loss: 0.0182, Accuracy: 2707/10000 (27.07%)

Train Epoch: 4 [0/60000 (0%)]	Loss: 2.302577
Train Epoch: 4 [12800/60000 (21%)]	Loss: 2.302552
Train Epoch: 4 [25600/60000 (43%)]	Loss: 2.302583
Train Epoch: 4 [38400/60000 (64%)]	Loss: 2.302581
Train Epoch: 4 [51200/60000 (85%)]	Loss: 2.302576
Train Epoch: 4	Average loss: 0.0180	Accuracy: 21.94%

Test set: Average loss: 0.0182, Accuracy: 1935/10000 (19.35%)

Train Epoch: 5 [0/60000 (0%)]	Loss: 2.302576
Train Epoch: 5 [12800/60000 (21%)]	Loss: 2.302568
Train Epoch: 5 [25600/60000 (43%)]	Loss: 2.302579
Train Epoch: 5 [38400/60000 (64%)]	Loss: 2.302577
Train Epoch: 5 [51200/60000 (85%)]	Loss: 2.302578
Train Epoch: 5	Average loss: 0.0180	Accuracy: 23.75%

Test set: Average loss: 0.0182, Accuracy: 2909/10000 (29.09%)

Train Epoch: 6 [0/60000 (0%)]	Loss: 2.302570
Train Epoch: 6 [12800/60000 (21%)]	Loss: 2.302553
Train Epoch: 6 [25600/60000 (43%)]	Loss: 2.302286
Train Epoch: 6 [38400/60000 (64%)]	Loss: 1.515582
Train Epoch: 6 [51200/60000 (85%)]	Loss: 1.519818
Train Epoch: 6	Average loss: 0.0147	Accuracy: 66.09%

Test set: Average loss: 0.0119, Accuracy: 9556/10000 (95.56%)

Train Epoch: 7 [0/60000 (0%)]	Loss: 1.531795
Train Epoch: 7 [12800/60000 (21%)]	Loss: 1.491780
Train Epoch: 7 [25600/60000 (43%)]	Loss: 1.485394
Train Epoch: 7 [38400/60000 (64%)]	Loss: 1.485206
Train Epoch: 7 [51200/60000 (85%)]	Loss: 1.483377
Train Epoch: 7	Average loss: 0.0117	Accuracy: 97.54%

Test set: Average loss: 0.0118, Accuracy: 9801/10000 (98.01%)

Train Epoch: 8 [0/60000 (0%)]	Loss: 1.492681
Train Epoch: 8 [12800/60000 (21%)]	Loss: 1.511164
Train Epoch: 8 [25600/60000 (43%)]	Loss: 1.471053
Train Epoch: 8 [38400/60000 (64%)]	Loss: 1.472341
Train Epoch: 8 [51200/60000 (85%)]	Loss: 1.496702
Train Epoch: 8	Average loss: 0.0116	Accuracy: 98.42%

Test set: Average loss: 0.0118, Accuracy: 9787/10000 (97.87%)

Train Epoch: 9 [0/60000 (0%)]	Loss: 1.476871
Train Epoch: 9 [12800/60000 (21%)]	Loss: 1.483141
Train Epoch: 9 [25600/60000 (43%)]	Loss: 1.489299
Train Epoch: 9 [38400/60000 (64%)]	Loss: 1.478493
Train Epoch: 9 [51200/60000 (85%)]	Loss: 1.472141
Train Epoch: 9	Average loss: 0.0116	Accuracy: 98.89%

Test set: Average loss: 0.0117, Accuracy: 9856/10000 (98.56%)

Train Epoch: 10 [0/60000 (0%)]	Loss: 1.468742
Train Epoch: 10 [12800/60000 (21%)]	Loss: 1.479285
Train Epoch: 10 [25600/60000 (43%)]	Loss: 1.471160
Train Epoch: 10 [38400/60000 (64%)]	Loss: 1.468553
Train Epoch: 10 [51200/60000 (85%)]	Loss: 1.478799
Train Epoch: 10	Average loss: 0.0115	Accuracy: 99.06%

Test set: Average loss: 0.0117, Accuracy: 9819/10000 (98.19%)
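
The exact edit used for that run is not shown; a minimal sketch of it, assuming only the loss computation inside the training loop changed, could look like this:

def train_no_reconstruction(model, train_loader, optimizer, criterion, epoch):
    # Same as train() above, except the loss keeps only the classification term.
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        lengths, _ = model(data)           # reconstructions are computed but not used
        loss = criterion(lengths, target)  # no F.mse_loss(reconstructions, data) term
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                  f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')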

As the log shows, accuracy rises sharply starting from the sixth epoch, but progress in the earlier epochs is very slow, and training itself is also quite slow.
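
Note also that the script trains with CrossEntropyLoss applied directly to the capsule lengths, whereas the referenced Hinton paper uses a margin loss on those lengths. Purely for comparison (it was not used for the log above), a minimal sketch of that loss:

def margin_loss(lengths, target, m_pos=0.9, m_neg=0.1, lam=0.5):
    # Margin loss from "Dynamic Routing Between Capsules":
    # L_k = T_k * max(0, m+ - |v_k|)^2 + lam * (1 - T_k) * max(0, |v_k| - m-)^2
    one_hot = F.one_hot(target, num_classes=lengths.size(1)).float()
    loss = one_hot * F.relu(m_pos - lengths) ** 2 \
         + lam * (1.0 - one_hot) * F.relu(lengths - m_neg) ** 2
    return loss.sum(dim=1).mean()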
