DGL official tutorial: Capsule network tutorial

Capsule network tutorial

[paper] [tutorial] [PyTorch code]: This computer vision model has two key ideas. First, it represents each feature as a vector (instead of a scalar), called a capsule. Second, it replaces max-pooling with dynamic routing. Dynamic routing integrates a lower-level capsule into one or several higher-level capsules through non-parametric message passing. The tutorial shows how the latter can be implemented with DGL APIs.

Author: Jinjing Zhou, Jake Zhao, Zheng Zhang, Jinyang Li
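In this tutorial, you learn how to describe one of the more classical models in terms of graphs. The approach offers a different perspective. The tutorial describes how to implement a Capsule model for the capsule network.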

Key ideas of Capsule

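The Capsule model offers two key ideas: richer representation and dynamic routing.

Richer representation: In classic convolutional networks, a scalar value represents the activation of a given feature. By contrast, a capsule outputs a vector. The vector's length represents the probability that the feature is present. The vector's orientation represents the various properties of the feature (such as pose, deformation, and texture).

Image: https://i.imgur.com/55Ovkdh.png

Dynamic routing: The output of a capsule is sent to certain parents in the layer above, based on how well the capsule's prediction agrees with that of each parent. Such dynamic routing-by-agreement generalizes the static routing of max-pooling.
During training, routing is accomplished iteratively. Each iteration adjusts the routing weights between capsules based on their observed agreement, in a manner similar to the k-means algorithm or competitive learning.
In this tutorial, you see how a capsule's dynamic routing algorithm can be naturally expressed as a graph algorithm. The implementation is adapted from Cedric Chee, replacing only the routing layer. This version achieves similar speed and accuracy.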

Model implementation

Step 1: Setup and graph initialization

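The connectivity between two layers of capsules forms a directed, bipartite graph, as shown in the figure below.

Image: https://i.imgur.com/9tc6GLl.png

Each node $j$ is associated with a feature $v_j$, representing its capsule's output. Each edge is associated with the features $b_{ij}$ and $\hat{u}_{j|i}$. $b_{ij}$ determines the routing weights, and $\hat{u}_{j|i}$ represents the prediction of capsule $i$ for $j$.
Here's how we set up the graph and initialize the node and edge features.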

import torch.nn as nn
import torch as th
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import dgl


def init_graph(in_nodes, out_nodes, f_size):
    g = dgl.DGLGraph()
    all_nodes = in_nodes + out_nodes
    g.add_nodes(all_nodes)

    in_indx = list(range(in_nodes))
    out_indx = list(range(in_nodes, in_nodes + out_nodes))
    # add edges using edge broadcasting (one source node, many destinations)
    for u in in_indx:
        g.add_edges(u, out_indx)

    # init states
    g.ndata['v'] = th.zeros(all_nodes, f_size)
    g.edata['b'] = th.zeros(in_nodes * out_nodes, 1)
    return g
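
The routing layer later reshapes the flat edge feature `b` into an `(in_nodes, out_nodes)` matrix, which only works if the edges are stored in input-major order. The following pure-Python sketch mirrors the loop in `init_graph` without DGL to make that layout explicit. It assumes edge IDs follow insertion order; the sizes and the `edge_id` helper are hypothetical and not part of the tutorial.

```python
# Pure-Python sketch of the edge layout built by init_graph (no DGL required).
in_nodes, out_nodes = 3, 2  # small illustrative sizes, not the tutorial's 20x10

# Mirror the loop in init_graph: every input node is connected to every output node.
edges = []
for u in range(in_nodes):
    for v in range(in_nodes, in_nodes + out_nodes):
        edges.append((u, v))


def edge_id(i, j):
    """Hypothetical helper: position of edge (input i -> output j),
    assuming edge IDs are assigned in insertion order."""
    return i * out_nodes + j


print(edges)
print(edge_id(2, 1), edges[edge_id(2, 1)])
```

Under that assumption, row `i` of `b.view(in_nodes, out_nodes)` holds the out-edges of input capsule `i`, so a softmax over `dim=1` normalizes over each capsule's parents.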

Step 2: Define message passing functions

Here is the pseudocode of the Capsule routing algorithm.

Image: https://i.imgur.com/mv1W9Rv.png

Implement pseudocode lines 4-7 in the class DGLRoutingLayer with the following steps:

  1. Calculate coefficients on capsules.
    Coefficients are the softmax over all out-edges of in-capsules: $\textbf{c}_{i,j} = \text{softmax}(\textbf{b}_{i,j})$.
  2. Calculate the weighted sum over all in-capsules.
    The output of a capsule is equal to the weighted sum of its in-capsules: $s_j=\sum_i c_{ij}\hat{u}_{j|i}$.
  3. Squash outputs.
    Squash the length of a Capsule's output vector to the range (0,1), so it can represent the probability (of some feature being present):
    $v_j=\text{squash}(s_j)=\frac{||s_j||^2}{1+||s_j||^2}\frac{s_j}{||s_j||}$
  4. Update weights by the amount of agreement.
    The scalar product $\hat{u}_{j|i}\cdot v_j$ can be considered as how well capsule $i$ agrees with $j$. It is used to update $b_{ij}=b_{ij}+\hat{u}_{j|i}\cdot v_j$.

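
To make the squash step concrete, here is a small numeric sketch. It uses NumPy rather than PyTorch so that it runs on its own; the input vectors are made up for illustration, and the function follows the squash formula given in step 3.

```python
import numpy as np


def squash(s, axis=-1):
    """Rescale each vector to length ||s||^2 / (1 + ||s||^2), keeping its direction."""
    sq = np.sum(s ** 2, axis=axis, keepdims=True)
    return (sq / (1.0 + sq)) * (s / np.sqrt(sq))


s = np.array([[3.0, 4.0],    # length 5
              [0.3, 0.4]])   # length 0.5
v = squash(s)
lengths = np.linalg.norm(v, axis=1)
print(lengths)               # 25/26 for the long vector, 0.2 for the short one
print(v[0] / lengths[0])     # direction is unchanged: [0.6, 0.8]
```

Long vectors end up with a length close to 1 and short vectors are pushed toward 0, which is what lets the length be read as a probability.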
class DGLRoutingLayer(nn.Module):
    def __init__(self, in_nodes, out_nodes, f_size):
        super(DGLRoutingLayer, self).__init__()
        self.g = init_graph(in_nodes, out_nodes, f_size)
        self.in_nodes = in_nodes
        self.out_nodes = out_nodes
        self.in_indx = list(range(in_nodes))
        self.out_indx = list(range(in_nodes, in_nodes + out_nodes))

    def forward(self, u_hat, routing_num=1):
        self.g.edata['u_hat'] = u_hat

        # step 2 (line 5)
        def cap_message(edges):
            return {'m': edges.data['c'] * edges.data['u_hat']}

        self.g.register_message_func(cap_message)

        def cap_reduce(nodes):
            return {'s': th.sum(nodes.mailbox['m'], dim=1)}

        self.g.register_reduce_func(cap_reduce)

        for r in range(routing_num):
            # step 1 (line 4): normalize over out edges
            edges_b = self.g.edata['b'].view(self.in_nodes, self.out_nodes)
            self.g.edata['c'] = F.softmax(edges_b, dim=1).view(-1, 1)

            # Execute step 2 (line 5): send messages on all edges and sum them at the out-capsules
            self.g.update_all()

            # step 3 (line 6)
            self.g.nodes[self.out_indx].data['v'] = self.squash(self.g.nodes[self.out_indx].data['s'], dim=1)

            # step 4 (line 7)
            v = th.cat([self.g.nodes[self.out_indx].data['v']] * self.in_nodes, dim=0)
            self.g.edata['b'] = self.g.edata['b'] + (self.g.edata['u_hat'] * v).sum(dim=1, keepdim=True)

    @staticmethod
    def squash(s, dim=1):
        sq = th.sum(s ** 2, dim=dim, keepdim=True)
        s_norm = th.sqrt(sq)
        s = (sq / (1.0 + sq)) * (s / s_norm)
        return s
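
For comparison, the same four steps can be written with dense arrays instead of graph message passing. This is a minimal NumPy sketch, not the tutorial's implementation: `dense_routing` is a hypothetical name, and it assumes `u_hat` has already been reshaped from `(in_nodes * out_nodes, f_size)` to `(in_nodes, out_nodes, f_size)` in input-major order.

```python
import numpy as np


def softmax(x, axis):
    # Subtract the max for numerical stability before exponentiating.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


def squash(s, axis=-1):
    sq = np.sum(s ** 2, axis=axis, keepdims=True)
    return (sq / (1.0 + sq)) * (s / np.sqrt(sq))


def dense_routing(u_hat, num_iters=3):
    """u_hat: (in_nodes, out_nodes, f_size). Returns (v, c) from the last iteration."""
    in_nodes, out_nodes, _ = u_hat.shape
    b = np.zeros((in_nodes, out_nodes))
    for _ in range(num_iters):
        c = softmax(b, axis=1)                        # step 1: normalize over parents
        s = (c[:, :, None] * u_hat).sum(axis=0)       # step 2: weighted sum per parent
        v = squash(s)                                 # step 3: (out_nodes, f_size)
        b = b + (u_hat * v[None, :, :]).sum(axis=2)   # step 4: agreement update
    return v, c


rng = np.random.default_rng(0)
u_hat = rng.standard_normal((20, 10, 4))
v, c = dense_routing(u_hat)
print(v.shape, c.shape)
```

In the graph version, the multiplication in step 2 is the message function and the sum over `axis=0` is the reduce function.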

Step 3: Testing

Make a simple 20x10 capsule layer.

in_nodes = 20
out_nodes = 10
f_size = 4
u_hat = th.randn(in_nodes * out_nodes, f_size)
routing = DGLRoutingLayer(in_nodes, out_nodes, f_size)

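You can visualize a Capsule network's behavior by monitoring the entropy of the coupling coefficients. The entropy should start high and then drop, as the weights gradually concentrate on fewer edges.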

entropy_list = []
dist_list = []

for i in range(10):
    routing(u_hat)
    dist_matrix = routing.g.edata['c'].view(in_nodes, out_nodes)
    entropy = (-dist_matrix * th.log(dist_matrix)).sum(dim=1)
    entropy_list.append(entropy.data.numpy())
    dist_list.append(dist_matrix.data.numpy())

stds = np.std(entropy_list, axis=1)
means = np.mean(entropy_list, axis=1)
plt.errorbar(np.arange(len(entropy_list)), means, stds, marker='o')
plt.ylabel("Entropy of Weight Distribution")
plt.xlabel("Number of Routing")
plt.xticks(np.arange(len(entropy_list)))
plt.close()

Image: https://i.imgur.com/dMvu7p3.png
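
To calibrate the y-axis of the entropy plot, it helps to know the two extremes for 10 output capsules. The sketch below uses NumPy only; `row_entropy` is a hypothetical helper that treats zero-probability terms as contributing zero, which the tutorial's own entropy line does not do.

```python
import numpy as np


def row_entropy(p):
    """Shannon entropy (in nats) of each row; terms with p == 0 contribute 0."""
    p = np.asarray(p, dtype=float)
    safe = np.where(p > 0, p, 1.0)  # log(1) = 0, so zero-probability terms vanish
    return -(p * np.log(safe)).sum(axis=1)


out_nodes = 10
uniform = np.full((1, out_nodes), 1.0 / out_nodes)  # weights spread evenly
one_hot = np.eye(out_nodes)[:1]                     # all weight on one parent

print(row_entropy(uniform))  # ln(10), about 2.3026: the maximum
print(row_entropy(one_hot))  # zero entropy: the minimum
```

Because `b` is initialized to zeros, the first routing call uses uniform coefficients, so the curve starts at ln(10) for every input capsule and can only move down from there.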

Alternatively, we can also watch the evolution of the histograms.

import seaborn as sns
import matplotlib.animation as animation

fig = plt.figure(dpi=150)
fig.clf()
ax = fig.subplots()


def dist_animate(i):
    ax.cla()
    sns.distplot(dist_list[i].reshape(-1), kde=False, ax=ax)
    ax.set_xlabel("Weight Distribution Histogram")
    ax.set_title("Routing: %d" % (i))


ani = animation.FuncAnimation(fig, dist_animate, frames=len(entropy_list), interval=500)
plt.close()

Image: https://github.com/VoVAllen/DGL_Capsule/raw/master/routing_dist.gif

You can monitor how lower-level Capsules gradually attach to one of the higher-level ones.

import networkx as nx
from networkx.algorithms import bipartite

g = routing.g.to_networkx()
X, Y = bipartite.sets(g)
height_in = 10
height_out = height_in * 0.8
height_in_y = np.linspace(0, height_in, in_nodes)
height_out_y = np.linspace((height_in - height_out) / 2, height_out, out_nodes)
pos = dict()

fig2 = plt.figure(figsize=(8, 3), dpi=150)
fig2.clf()
ax = fig2.subplots()
pos.update((n, (i, 1)) for i, n in zip(height_in_y, X))  # put nodes from X at x=1
pos.update((n, (i, 2)) for i, n in zip(height_out_y, Y))  # put nodes from Y at x=2


def weight_animate(i):
    ax.cla()
    ax.axis('off')
    ax.set_title("Routing: %d  " % i)
    dm = dist_list[i]
    nx.draw_networkx_nodes(g, pos, nodelist=range(in_nodes), node_color='r', node_size=100, ax=ax)
    nx.draw_networkx_nodes(g, pos, nodelist=range(in_nodes, in_nodes + out_nodes), node_color='b', node_size=100, ax=ax)
    for edge in g.edges():
        nx.draw_networkx_edges(g, pos, edgelist=[edge], width=dm[edge[0], edge[1] - in_nodes] * 1.5, ax=ax)


ani2 = animation.FuncAnimation(fig2, weight_animate, frames=len(dist_list), interval=500)
plt.close()

Image: https://github.com/VoVAllen/DGL_Capsule/raw/master/routing_vis.gif
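
A numeric counterpart to the animation is to read off, for each lower-level capsule, the parent that receives the largest coupling coefficient. The matrix below is made up for illustration; with the tutorial's variables, one would presumably pass `dist_list[-1]` instead.

```python
import numpy as np

# Hypothetical coupling coefficients: 3 lower-level capsules (rows), 2 parents (columns).
dm = np.array([[0.9, 0.1],
               [0.2, 0.8],
               [0.6, 0.4]])

parent = dm.argmax(axis=1)     # index of the strongest parent per lower-level capsule
strength = dm.max(axis=1)      # how much of the capsule's weight that parent receives
counts = np.bincount(parent, minlength=dm.shape[1])  # children attached to each parent

print(parent)    # [0 1 0]
print(strength)  # [0.9 0.8 0.6]
print(counts)    # [2 1]
```

As routing proceeds, the values in `strength` should move toward 1, matching the thick edges in the animation.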

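The full code of this visualization is provided on GitHub. The complete code that trains on MNIST is also on GitHub.

Total running time of the script: (0 minutes 0.281 seconds)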

