PyTorch capsule network (CapsNet)

136 篇文章 17 订阅
55 篇文章 4 订阅
import torch.nn as nn
import torch.nn.functional as F
import torch
from torch.autograd import Variable
from torchvision import datasets, transforms
import numpy as np
from tqdm import tqdm
from matplotlib import pyplot as plt
import random
import pickle
import csv

%matplotlib inline

PyTorch version 1.2.0

# Torch version string is embedded in checkpoint/log filenames below.
version = torch.__version__ # '1.2.0'
# Prefer the first GPU when available, otherwise fall back to CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

Model

# Model parameters

image_dim_size = 28  # MNIST images are 28x28

# Conv layer (first feature extractor, valid convolution)
cl_filter_size = 9
cl_num_filters = 256
cl_input_channels = 1
cl_stride = 1
cl_output_dim = int((image_dim_size - cl_filter_size + 1) / cl_stride) # added for pedagogical purpose

# Primary caps

# Primary Caps are equivalent to running 256 conv filters (pc_caps_dim x pc_num_caps_channels) 
# of size 9x9 (pc_filter_size), with stride 2, each one with 256 input channels (pc_input_channels). 
# This leads to a 6x6x256 output volume.
# Then each capsule comes from stacking 8 (pc_caps_dim) of these consecutive filters 
# on each pixel location. Hence, we will end with 6 x 6 x pc_num_caps_channels 
# capsules of dimension 8.

pc_input_size = cl_output_dim # added for pedagogical purpose
pc_filter_size = 9
pc_stride = 2
pc_caps_dim = 8
pc_input_channels = cl_num_filters # added for pedagogical purpose
pc_num_caps_channels = 32 
pc_output_dim = int((pc_input_size - pc_filter_size + 1) / pc_stride) # added for pedagogical purpose

# Digit caps: one capsule of dimension dc_caps_dim per digit class
dc_num_caps = 10
dc_caps_dim = 16

# Routing parameters: number of routing-by-agreement iterations
iterations = 3

# Regularisation: weight of the reconstruction term (0.0005 scaled by the
# pixel count, to compensate for the mean over pixels taken by MSELoss)
reconst_loss_scale = 0.0005 * image_dim_size**2

Capsule network

# Capsule model
class CapsModel(nn.Module):
    """Capsule network for MNIST in the spirit of Sabour, Frosst & Hinton
    (2017), "Dynamic Routing Between Capsules".

    Pipeline: conv layer -> primary capsules (one big conv reshaped into
    pc_caps_dim-dimensional capsules) -> digit capsules coupled by dynamic
    routing-by-agreement. The loss combines a margin loss on capsule lengths
    with a scaled reconstruction MSE.
    """
    def __init__(self, 
                 cl_input_channels, 
                 cl_num_filters, 
                 cl_filter_size, 
                 cl_stride,
                 pc_input_channels,
                 pc_num_caps_channels,
                 pc_caps_dim,
                 pc_filter_size,
                 pc_stride,
                 image_dim_size,
                 dc_num_caps,
                 dc_caps_dim,
                 iterations,
                 reconst_loss_scale):
        
        super(CapsModel, self).__init__()
        
        self.iterations = iterations  # number of routing iterations
        self.pc_caps_dim = pc_caps_dim  # dimension of each primary capsule
        self.reconst_loss_scale = reconst_loss_scale  # weight of the reconstruction term
        
        # First feature extractor: plain conv (ReLU applied in forward).
        self.conv_layer_1 = nn.Conv2d(in_channels=cl_input_channels,
                              out_channels=cl_num_filters,
                              kernel_size=cl_filter_size, 
                              stride=cl_stride)
        
        # Primary capsules implemented as a single conv whose channels are
        # later reshaped into capsules of size pc_caps_dim.
        self.conv_layer_2 = nn.Conv2d(in_channels=pc_input_channels,
                                     out_channels=pc_num_caps_channels * pc_caps_dim,
                                     kernel_size=pc_filter_size,
                                     stride=pc_stride)
        
        # Spatial output sizes of the two (valid) convolutions.
        cl_output_dim = int((image_dim_size - cl_filter_size + 1) / cl_stride)
        pc_output_dim = int((cl_output_dim - pc_filter_size + 1) / pc_stride)
        self.pc_num_caps = pc_output_dim*pc_output_dim*pc_num_caps_channels
        self.W = nn.Parameter(0.01 * torch.randn(1,
                                            self.pc_num_caps, # One weight matrix per pair of capsules in the
                                            dc_num_caps,      # primary-capsule and digit-capsule layers.
                                            dc_caps_dim,      # Each matrix maps a pc_caps_dim vector
                                            pc_caps_dim))     # to a dc_caps_dim prediction vector.
        
        self.reconst_loss = nn.MSELoss()
        
    def forward(self, x):
        """Return digit-capsule vectors of shape (batch, dc_num_caps, dc_caps_dim)."""
        c = F.relu(self.conv_layer_1(x))
        u = self.conv_layer_2(c)
        # Reshape conv output into primary capsules, squash, and append dummy
        # dims so each capsule can be matrix-multiplied with W below.
        u_sliced = self.squash( u.permute(0,2,3,1).contiguous()\
                               .view(c.shape[0], -1, self.pc_caps_dim)).unsqueeze(-1).unsqueeze(2)
        u_hat = torch.matmul(self.W, u_sliced).squeeze(4)  # Removing last dummy dimension, not needed anymore
        v = self.routing(u_hat).squeeze(1)
        return v
    
    def routing(self, u_hat):
        """Dynamic routing-by-agreement over the prediction vectors u_hat."""
        b = torch.zeros_like( u_hat )  #b_ij parameters have the same dimension as u_hat
        # Gradients flow only through the final iteration; earlier iterations
        # use a detached copy of u_hat.
        u_hat_routing = u_hat.detach()
        for i in range(self.iterations):
            c = F.softmax(b, dim=2)   # coupling coefficients over digit capsules
            if i==(self.iterations-1):
                s = (c*u_hat).sum(1, keepdim=True)
            else:
                s = (c*u_hat_routing).sum(1, keepdim=True)
            v = self.squash(s)
            # Agreement update: primary capsules whose predictions align with
            # the current output v get a larger routing logit b.
            if i < self.iterations - 1: b = (b + (u_hat_routing*v).sum(3, keepdim=True))
        return v
            
    def squash(self, s):
        """Squashing non-linearity on the last dimension.

        Equivalent to (||s||^2 / (1 + ||s||^2)) * s / ||s||, written in the
        algebraically identical form ||s|| / (1 + ||s||^2) * s.
        """
        s_norm = s.norm(dim=-1, keepdim=True)
        v = s_norm / (1. + s_norm**2) * s
        return v    
    
    # Loss = margin loss on capsule lengths + scaled reconstruction MSE.
    # T is the one-hot target matrix; m_plus/m_minus are the margins for the
    # present/absent classes, lambda_param down-weights the absent-class term.
    def loss(self, T, v, x_true, x_reconstructed, lambda_param=0.5, m_plus=0.9, m_minus=0.1):     
        v_norm = v.norm(dim=2, keepdim=False)
        return (T*F.relu(m_plus - v_norm)**2 + lambda_param * (1-T)*F.relu(v_norm - m_minus)**2).sum(1).mean() \
                +self.reconst_loss_scale * self.reconst_loss(x_reconstructed, x_true.view(x_true.shape[0],-1))

Decoder

class Decoder(nn.Module):
    """Reconstruction decoder.

    Masks out every digit capsule except the one selected by the one-hot
    labels, then maps the surviving capsule vector through an MLP to a
    flattened image with values in (0, 1).
    """

    def __init__(self, dc_caps_dim, dc_num_caps, image_dim_size):
        super(Decoder, self).__init__()
        self.dc_num_caps = dc_num_caps
        layers = [
            nn.Linear(dc_caps_dim * dc_num_caps, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, image_dim_size**2),
            nn.Sigmoid(),
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, v, y_ohe):
        """Reconstruct images from capsules v, keeping only the labelled capsule."""
        masked = y_ohe[:, :, None] * v
        flat = masked.view(masked.shape[0], -1)
        return self.network(flat)
# Instantiate the capsule network on the selected device.
model = CapsModel(cl_input_channels, 
                 cl_num_filters, 
                 cl_filter_size, 
                 cl_stride,
                 pc_input_channels,
                 pc_num_caps_channels,
                 pc_caps_dim,
                 pc_filter_size,
                 pc_stride,
                 image_dim_size,
                 dc_num_caps,
                 dc_caps_dim, 
                 iterations,
                 reconst_loss_scale).to(device)
'''
CapsModel(
  (conv_layer_1): Conv2d(1, 256, kernel_size=(9, 9), stride=(1, 1))
  (conv_layer_2): Conv2d(256, 256, kernel_size=(9, 9), stride=(2, 2))
  (reconst_loss): MSELoss()
)
'''

# Reconstruction decoder, trained jointly with the capsule network.
decoder = Decoder(dc_caps_dim, dc_num_caps, image_dim_size,).to(device)
'''
Decoder(
  (network): Sequential(
    (0): Linear(in_features=160, out_features=512, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=512, out_features=1024, bias=True)
    (3): ReLU(inplace=True)
    (4): Linear(in_features=1024, out_features=784, bias=True)
    (5): Sigmoid()
  )
)
'''
# One Adam optimiser over the parameters of both modules.
optimiser = torch.optim.Adam(list(model.parameters()) + list(decoder.parameters()), lr=0.001)
# Exponential decay: multiply the lr by 0.96 every 2000 steps
# (the scheduler is stepped once per batch in the training loop).
lr_decay = torch.optim.lr_scheduler.ExponentialLR(optimiser, gamma=0.96**(1/2000.))

Total trainable parameter count: capsule network + decoder

def count_parameters(model):
    """Return the total number of trainable parameters in *model*."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Total trainable parameters across both modules (notebook-style expression;
# value shown in the string below).
count_parameters(model) + count_parameters(decoder)
'''
8215568
'''

Training

loss_train = 10     # Initialize with arbitrary high value
num_epoch=10

# Per-epoch history for plotting learning curves after training.
losses_train = []
losses_val = []

errors_train = []
errors_val = []

min_error = np.inf       # best validation error so far, used for checkpointing
error_rate_batch = 1.    # exponential moving average of the per-batch error rate

train_sample_built = False

# CSV training log. newline='' is required by the csv module when handing it
# an open file object (otherwise spurious blank rows appear on Windows).
logfile = open('log_'+version+'.csv', 'w', newline='')
logwriter = csv.DictWriter(logfile, fieldnames=['epoch', 'train_loss', 'val_loss', 'val_error'])
logwriter.writeheader()
def error_rate_calc(v_list, y_labels_list):
    """Classification error rate from digit-capsule outputs.

    The predicted class is the capsule with the largest L2 norm.

    :param v_list: list of capsule tensors of shape (batch, num_classes, caps_dim)
    :param y_labels_list: list of integer label tensors, one per batch
    :return: fraction of misclassified samples, as a Python float
    """
    v = torch.cat(v_list)
    # Move labels to wherever the capsule outputs live instead of the original
    # hard-coded .cuda(), which crashed on CPU-only machines even though the
    # rest of the file selects a device dynamically.
    y_labels = torch.cat(y_labels_list).to(v.device)
    _, y_pred = v.norm(p=2, dim=2).max(dim=1)
    return float( 1 - (y_pred == y_labels).float().mean() )

def load_mnist(path='./data', download=True, batch_size=128, shift_pixels=2):
    """
    Construct dataloaders for training and test data. Data augmentation is also done here.
    :param path: file path of the dataset
    :param download: whether to download the original data
    :param batch_size: batch size, used for both the train and test loaders
    :param shift_pixels: maximum number of pixels to shift in each direction
    :return: train_loader, test_loader
    """
    kwargs = {'num_workers': 2, 'pin_memory': True}

    # Random crops of the padded image implement the +/- shift_pixels
    # translation augmentation.
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(path, 
                       train=True, 
                       download=download,
                       transform=transforms.Compose([transforms.RandomCrop(size=28, padding=shift_pixels),
                                                     transforms.ToTensor()])),
        batch_size=batch_size,  # fixed: was hard-coded to 256, silently ignoring the parameter
        shuffle=True, 
        **kwargs)
    
    
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(path, 
                       train=False, 
                       download=download,
                       transform=transforms.ToTensor()),
        batch_size=batch_size, 
        shuffle=True, 
        **kwargs)

    return train_loader, test_loader

train_loader, test_loader = load_mnist()
# Training loop: one pass over the train set per epoch, followed by a full
# evaluation pass over the test set used as validation.
epoch = 0
while epoch < num_epoch:
    loss_train = 0
    with tqdm(total=len(train_loader)) as pbar:
        model.train()
        for x_train, y_train in train_loader:

            # Move the batch to the selected device. The original hard-coded
            # .cuda() (and wrapped tensors in the long-deprecated, no-op
            # Variable), which broke CPU-only runs.
            x_train = x_train.to(device).float()
            y_train_ohe = F.one_hot(y_train, 10).float().to(device)

            v = model(x_train)

            # Calculate loss and do backward step
            loss = model.loss(y_train_ohe, v, x_train, decoder(v, y_train_ohe))
            model.zero_grad()
            decoder.zero_grad()
            loss.backward()
            optimiser.step()

            # Exponential moving average of the batch error rate (for the bar)
            _, y_pred_batch = v.norm(p=2, dim=2).max(dim=1)
            error_rate_batch = error_rate_batch*0.9 + float( 1 - (y_pred_batch.data == y_train.to(device)).float().mean() ) * 0.1

            # Performance reporting
            loss_train += loss.item()
            pbar.set_postfix(error_rate_batch=error_rate_batch)
            pbar.update(1)
            lr_decay.step()   # scheduler stepped per batch (0.96 decay / 2000 steps)

    loss_train /= len(train_loader)

    # Error rate and loss on the validation set. no_grad() avoids building
    # autograd graphs during evaluation (the original kept them alive).
    v_val_list = []
    loss_test_list = []
    y_test_list = []
    model.eval()
    with torch.no_grad():
        for x_test, y_test in test_loader:
            x_test = x_test.to(device).float()
            y_test_ohe = F.one_hot(y_test, 10).float().to(device)

            v_val_batch = model(x_test)
            v_val_list.append(v_val_batch)
            loss_test_list.append(model.loss(y_test_ohe,
                                             v_val_batch,
                                             x_test,
                                             decoder(v_val_batch, y_test_ohe)).item())
            y_test_list.append(y_test)

    # Error rate: predicted class = capsule with the largest L2 norm
    v_val = torch.cat(v_val_list)
    y_test_labels = torch.cat(y_test_list).to(device)
    _, y_pred_test = v_val.norm(p=2, dim=2).max(dim=1)
    error_rate = float( 1 - (y_pred_test == y_test_labels).float().mean() )

    # Calculates loss function on validation set
    loss_val = np.mean(loss_test_list)

    # Stores loss values for train and validation
    losses_train.append(loss_train)
    errors_train.append(error_rate_batch)
    losses_val.append(loss_val)
    errors_val.append(error_rate)

    # Checkpoint whenever the validation error improves
    if error_rate < min_error:
        torch.save(model.state_dict(), 'model_'+version+'.pickle')
        torch.save(decoder.state_dict(), 'decode_'+version+'.pickle')
        min_error = error_rate

    # Print and log some results. Reading the lr from param_groups is portable;
    # scheduler.get_lr() is deprecated/misleading in newer PyTorch versions.
    current_lr = optimiser.param_groups[0]['lr']
    print("epoch:{}\t loss_train:{:.4f}\t loss_val:{:.4f}\t error_rate:{:.4f}\t learning_rate:{:.3f}".format(epoch, loss_train, loss_val, error_rate, current_lr))
    logwriter.writerow(dict(epoch=epoch, train_loss=loss_train,
                                val_loss=loss_val, val_error=error_rate))

    epoch += 1

'''
100%|██████████████████████████████████████████████████████| 235/235 [00:58<00:00,  4.00it/s, error_rate_batch=0.00741]
epoch:5	 loss_train:0.0263	 loss_val:0.0216	 error_rate:0.0066	 learning_rate:0.001
100%|██████████████████████████████████████████████████████| 235/235 [00:59<00:00,  3.96it/s, error_rate_batch=0.00751]
epoch:6	 loss_train:0.0239	 loss_val:0.0198	 error_rate:0.0058	 learning_rate:0.001
100%|██████████████████████████████████████████████████████| 235/235 [00:59<00:00,  3.95it/s, error_rate_batch=0.00524]
epoch:7	 loss_train:0.0219	 loss_val:0.0187	 error_rate:0.0065	 learning_rate:0.001
100%|██████████████████████████████████████████████████████| 235/235 [00:59<00:00,  3.94it/s, error_rate_batch=0.00378]
epoch:8	 loss_train:0.0204	 loss_val:0.0203	 error_rate:0.0074	 learning_rate:0.001
100%|██████████████████████████████████████████████████████| 235/235 [00:59<00:00,  3.94it/s, error_rate_batch=0.00706]
epoch:9	 loss_train:0.0194	 loss_val:0.0179	 error_rate:0.0053	 learning_rate:0.001
'''

Loss decrease during training

# Learning curves: per-epoch error rates and losses on one set of axes.
plt.figure(figsize=(10, 7))
curves = [
    (errors_val, 'Error rate test'),
    (errors_train, 'Error rate train'),
    (losses_val, 'Loss test'),
    (losses_train, 'Loss train'),
]
for series, curve_label in curves:
    plt.plot(series, label=curve_label)
plt.ylabel('Error rate / Loss')
plt.xlabel('Epoch')
plt.ylim((0, 0.2))
plt.legend()
plt.show()

在这里插入图片描述
Reconstructing images from capsules with the decoder

# Show a few training images next to their decoder reconstructions,
# using the last training batch left over from the loop above.
for sample_idx in range(5):
    image = x_train[sample_idx]
    caps = model(image.unsqueeze(0))
    label_ohe = y_train_ohe[sample_idx].unsqueeze(0)
    recon = decoder(caps, label_ohe).detach().cpu().view(28, 28)
    plt.subplot(1, 2, 1)
    plt.imshow(image.view(28, 28).cpu(), cmap='gray')
    plt.subplot(1, 2, 2)
    plt.imshow(recon, cmap='gray')
    plt.show()

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值