利用胶囊网络实现mnist数字分类——分辨多张数字重叠

利用胶囊网络实现mnist数字分类——分辨多张数字重叠分类

引言:本文对比普通CNN与CNN+Capsule在"重叠数字"分类任务上的表现,以印证胶囊网络对位置、缩放、姿态变化更为鲁棒。两个模型都只用单个数字的图片训练;测试时把两张不同数字的图片上下拼接成一张图片,要求模型同时识别出这两个数字。

1  首先构建胶囊网络层

from keras import activations
from keras import backend as K
from keras.engine.topology import Layer

def squash(x, axis=-1):
    """Rescale the vectors in ``x`` along ``axis`` without changing direction.

    Each vector ``s`` is multiplied by ``sqrt(n) / (0.5 + n)`` where
    ``n = sum(s ** 2) + K.epsilon()``.  Note that the constant in the
    denominator is 0.5, not the 1 used in the squashing formula of the
    original CapsNet paper.

    Args:
        x: Input tensor.
        axis: Axis along which the squared norm is computed.

    Returns:
        A tensor of the same shape as ``x``.
    """
    # The epsilon keeps the square root away from exactly zero.
    sq_norm = K.epsilon() + K.sum(K.square(x), axis=axis, keepdims=True)
    return x * (K.sqrt(sq_norm) / (sq_norm + 0.5))


# Hand-written softmax used instead of K.softmax so that the axis can be chosen.
def softmax(x, axis=-1):
    """Return the softmax of ``x`` along ``axis``.

    The maximum along ``axis`` is subtracted before exponentiation, which
    leaves the result unchanged mathematically but avoids overflow in exp.
    """
    shifted = x - K.max(x, axis=axis, keepdims=True)
    exps = K.exp(shifted)
    total = K.sum(exps, axis=axis, keepdims=True)
    return exps / total


# A capsule layer implemented with the Keras backend API only.
class Capsule(Layer):
    """Capsule layer with iterative routing.

    The layer takes a sequence of input capsules of shape
    ``[batch, input_num_capsule, input_dim_capsule]`` and produces
    ``num_capsule`` output capsules of size ``dim_capsule`` each.

    Args:
        num_capsule: Number of output capsules.
        dim_capsule: Dimension of every output capsule vector.
        routings: Number of routing iterations.
        share_weights: If true, a single transformation matrix is shared by
            all input capsules; otherwise every input capsule position gets
            its own matrix (the number of input capsules must then be known
            when the layer is built).
        activation: ``'squash'`` to use the module-level ``squash`` function,
            otherwise anything accepted by ``keras.activations.get``.
        **kwargs: Forwarded to the base ``Layer`` constructor.
    """

    def __init__(self, num_capsule, dim_capsule, routings=3, share_weights=True, activation='squash', **kwargs):
        super(Capsule, self).__init__(**kwargs)
        self.num_capsule = num_capsule
        self.dim_capsule = dim_capsule
        self.routings = routings
        self.share_weights = share_weights
        if activation == 'squash':
            self.activation = squash
        else:
            self.activation = activations.get(activation)

    def build(self, input_shape):
        """Create the transformation kernel ``self.W`` for ``input_shape``."""
        super(Capsule, self).build(input_shape)
        input_dim_capsule = input_shape[-1]
        if self.share_weights:
            # Kernel of width 1: conv1d applies the same linear map to
            # every input capsule.
            self.W = self.add_weight(name='capsule_kernel',
                                     shape=(1, input_dim_capsule,
                                            self.num_capsule * self.dim_capsule),
                                     initializer='glorot_uniform',
                                     trainable=True)
        else:
            # One kernel slice per input capsule position.
            input_num_capsule = input_shape[-2]
            self.W = self.add_weight(name='capsule_kernel',
                                     shape=(input_num_capsule,
                                            input_dim_capsule,
                                            self.num_capsule * self.dim_capsule),
                                     initializer='glorot_uniform',
                                     trainable=True)

    def call(self, u_vecs):
        """Transform the input capsules and route them to the output capsules.

        Args:
            u_vecs: Tensor of input capsules,
                ``[batch, input_num_capsule, input_dim_capsule]``.

        Returns:
            The activated output capsules.
        """
        if self.share_weights:
            u_hat_vecs = K.conv1d(u_vecs, self.W)
        else:
            u_hat_vecs = K.local_conv1d(u_vecs, self.W, [1], [1])

        batch_size = K.shape(u_vecs)[0]
        input_num_capsule = K.shape(u_vecs)[1]
        # Split the flat last axis into (num_capsule, dim_capsule).
        u_hat_vecs = K.reshape(u_hat_vecs, (batch_size, input_num_capsule,
                                            self.num_capsule, self.dim_capsule))
        u_hat_vecs = K.permute_dimensions(u_hat_vecs, (0, 2, 1, 3))
        # u_hat_vecs is now [batch, num_capsule, input_num_capsule, dim_capsule].

        # Routing logits start at zero: [batch, num_capsule, input_num_capsule].
        b = K.zeros_like(u_hat_vecs[:, :, :, 0])
        for i in range(self.routings):
            # Coupling coefficients: normalised over the output-capsule axis.
            c = softmax(b, 1)
            # Contract the input-capsule axis of c with that of u_hat_vecs.
            # NOTE(review): presumably yields [batch, num_capsule, dim_capsule]
            # on TensorFlow; the exact layout depends on the Keras version's
            # batch_dot semantics -- confirm before upgrading Keras.
            o = K.batch_dot(c, u_hat_vecs, [2, 2])
            if K.backend() == 'theano':
                # NOTE(review): extra reduction apparently needed because of
                # the Theano batch_dot result layout -- TODO confirm.
                o = K.sum(o, axis=1)
            if i < self.routings - 1:
                # All but the last iteration update the logits from the
                # agreement between the normalised outputs and the predictions.
                o = K.l2_normalize(o, -1)
                b = K.batch_dot(o, u_hat_vecs, [2, 3])
                if K.backend() == 'theano':
                    b = K.sum(b, axis=1)

        return self.activation(o)

    def compute_output_shape(self, input_shape):
        """Return ``(batch, num_capsule, dim_capsule)``."""
        # Pass the batch dimension through instead of hard-coding None so a
        # statically known batch size is not lost.
        return (input_shape[0], self.num_capsule, self.dim_capsule)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created.

        Without this, the custom arguments are dropped when the model
        configuration is serialized and the layer cannot be rebuilt.
        """
        if self.activation is squash:
            activation = 'squash'
        else:
            activation = activations.serialize(self.activation)
        config = {
            'num_capsule': self.num_capsule,
            'dim_capsule': self.dim_capsule,
            'routings': self.routings,
            'share_weights': self.share_weights,
            'activation': activation,
        }
        base_config = super(Capsule, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

2 搭建网络:对比普通CNN模型与CNN+Capsule模型

#! -*- coding: utf-8 -*-

#from Capsule_Keras import *
import numpy as np
from keras import utils
from keras.datasets import mnist
from keras.models import Model
from keras.layers import *
from keras import backend as K


# Training configuration and input geometry.
batch_size = 128
num_classes = 10
img_rows, img_cols = 28, 28
# Load MNIST, add a trailing channel axis and divide the pixel values by 255.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
# Convert the integer labels to one-hot vectors.
y_train = utils.to_categorical(y_train, num_classes)
y_test = utils.to_categorical(y_test, num_classes)


# Build the custom two-digit test set: every test image is paired with a
# randomly chosen other test image and the two are stacked along axis 1
# (image rows), so each new sample contains two digits.
# np.random.permutation is used instead of shuffling range(...) in place,
# because range objects are immutable on Python 3 and cannot be shuffled.
idx = np.random.permutation(len(x_test))
# Concatenation only requires the arrays to agree on every axis except the
# one being joined.
X_test = np.concatenate([x_test, x_test[idx]], 1)
# One row per sample holding the two digit labels.
Y_test = np.vstack([y_test.argmax(1), y_test[idx].argmax(1)]).T
# Keep only the pairs whose two digits differ; the mask is computed once and
# applied to both the images and the labels.
distinct_pair = Y_test[:, 0] != Y_test[:, 1]
X_test = X_test[distinct_pair]
Y_test = Y_test[distinct_pair]
Y_test.sort(axis=1)  # order within a pair is irrelevant: compare as sets


# Baseline: a plain CNN classifier with global average pooling.
# The spatial dimensions are left as None so that the network trained on
# single-digit images can also be applied to the taller two-digit test images.

input_image = Input(shape=(None, None, 1))
cnn = Conv2D(64, (3, 3), activation='relu')(input_image)
cnn = Conv2D(64, (3, 3), activation='relu')(cnn)
cnn = AveragePooling2D((2, 2))(cnn)
cnn = Conv2D(128, (3, 3), activation='relu')(cnn)
cnn = Conv2D(128, (3, 3), activation='relu')(cnn)
cnn = GlobalAveragePooling2D()(cnn)
dense = Dense(128, activation='relu')(cnn)
# Sigmoid rather than softmax: every class gets an independent score, so
# more than one class can score high at the same time.
output = Dense(10, activation='sigmoid')(dense)

model = Model(inputs=input_image, outputs=output)
# Margin loss: a positive class is penalised when its score is below 0.9, a
# negative class (weighted by 0.25) when its score is above 0.1.
model.compile(loss=lambda y_true, y_pred: y_true*K.relu(0.9-y_pred)**2 + 0.25*(1-y_true)*K.relu(y_pred-0.1)**2,
              optimizer='adam',
              metrics=['accuracy'])

model.summary()

model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=20,
          verbose=1,
          validation_data=(x_test, y_test))

Y_pred = model.predict(X_test)  # class scores for the two-digit images
# Confidence check: the second-highest score must exceed 0.5 as well.
greater = np.sort(Y_pred, axis=1)[:, -2] > 0.5
Y_pred = Y_pred.argsort()[:, -2:]  # indices of the two highest-scoring classes
Y_pred.sort(axis=1)  # sort so that the pairs are compared as sets

# A sample counts as correct only when both predicted digits match.
correct = (Y_pred == Y_test).all(axis=1)
# print(...) with one argument behaves the same on Python 2 and Python 3.
acc = 1. * correct.sum() / len(X_test)
print(u'CNN+Pooling,不考虑置信度的准确率为:%s' % acc)
acc = 1. * (correct & greater).sum() / len(X_test)
print(u'CNN+Pooling,考虑置信度的准确率为:%s' % acc)



# CNN + Capsule classifier.
# A regular Conv2D feature extractor; the spatial dimensions are left as None
# so the network can be applied to images of a different size at test time.
input_image = Input(shape=(None, None, 1))
cnn = Conv2D(64, (3, 3), activation='relu')(input_image)
cnn = Conv2D(64, (3, 3), activation='relu')(cnn)
cnn = AveragePooling2D((2, 2))(cnn)
cnn = Conv2D(128, (3, 3), activation='relu')(cnn)
cnn = Conv2D(128, (3, 3), activation='relu')(cnn)

# Flatten the spatial grid into a sequence of 128-dimensional vectors, which
# the capsule layer treats as its input capsules.
cnn = Reshape((-1, 128))(cnn)
# 10 output capsules of dimension 16, 3 routing iterations, shared weights.
capsule = Capsule(10, 16, 3, True)(cnn)
# The class score is the length (L2 norm) of the corresponding capsule vector.
output = Lambda(lambda x: K.sqrt(K.sum(K.square(x), 2)), output_shape=(10,))(capsule)

model = Model(inputs=input_image, outputs=output)
# Margin loss: a positive class is penalised when its score is below 0.9, a
# negative class (weighted by 0.25) when its score is above 0.1.
model.compile(loss=lambda y_true, y_pred: y_true*K.relu(0.9-y_pred)**2 + 0.25*(1-y_true)*K.relu(y_pred-0.1)**2,
              optimizer='adam',
              metrics=['accuracy'])

model.summary()

model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=20,
          verbose=1,
          validation_data=(x_test, y_test))

Y_pred = model.predict(X_test)  # class scores for the two-digit images
# Confidence check: the second-highest score must exceed 0.5 as well.
greater = np.sort(Y_pred, axis=1)[:, -2] > 0.5
Y_pred = Y_pred.argsort()[:, -2:]  # indices of the two highest-scoring classes
Y_pred.sort(axis=1)  # sort so that the pairs are compared as sets

# A sample counts as correct only when both predicted digits match.
correct = (Y_pred == Y_test).all(axis=1)
# print(...) with one argument behaves the same on Python 2 and Python 3.
acc = 1. * correct.sum() / len(X_test)
print(u'CNN+Capsule,不考虑置信度的准确率为:%s' % acc)
acc = 1. * (correct & greater).sum() / len(X_test)
print(u'CNN+Capsule,考虑置信度的准确率为:%s' % acc)

3 实验结果

Layer (type)                 Output Shape              Param #   
=================================================================
input_3 (InputLayer)         (None, None, None, 1)     0         
_________________________________________________________________
conv2d_9 (Conv2D)            (None, None, None, 64)    640       
_________________________________________________________________
conv2d_10 (Conv2D)           (None, None, None, 64)    36928     
_________________________________________________________________
average_pooling2d_3 (Average (None, None, None, 64)    0         
_________________________________________________________________
conv2d_11 (Conv2D)           (None, None, None, 128)   73856     
_________________________________________________________________
conv2d_12 (Conv2D)           (None, None, None, 128)   147584    
_________________________________________________________________
global_average_pooling2d_2 ( (None, 128)               0         
_________________________________________________________________
dense_3 (Dense)              (None, 128)               16512     
_________________________________________________________________
dense_4 (Dense)              (None, 10)                1290      
=================================================================
Total params: 276,810
Trainable params: 276,810
Non-trainable params: 0
_________________________________________________________________
Train on 60000 samples, validate on 10000 samples
Epoch 1/20
60000/60000 [==============================] - 7s 110us/step - loss: 0.0126 - acc: 0.7945 - val_loss: 0.0031 - val_acc: 0.9573
Epoch 2/20
60000/60000 [==============================] - 6s 97us/step - loss: 0.0024 - acc: 0.9658 - val_loss: 0.0024 - val_acc: 0.9647
Epoch 3/20
60000/60000 [==============================] - 6s 97us/step - loss: 0.0016 - acc: 0.9763 - val_loss: 0.0011 - val_acc: 0.9822
Epoch 4/20
60000/60000 [==============================] - 6s 98us/step - loss: 0.0012 - acc: 0.9821 - val_loss: 9.1222e-04 - val_acc: 0.9863
Epoch 5/20
60000/60000 [==============================] - 6s 100us/step - loss: 9.3207e-04 - acc: 0.9859 - val_loss: 7.4821e-04 - val_acc: 0.9892
Epoch 6/20
60000/60000 [==============================] - 6s 100us/step - loss: 8.1626e-04 - acc: 0.9885 - val_loss: 8.1356e-04 - val_acc: 0.9879
Epoch 7/20
60000/60000 [==============================] - 6s 97us/step - loss: 6.9253e-04 - acc: 0.9895 - val_loss: 7.2233e-04 - val_acc: 0.9896
Epoch 8/20
60000/60000 [==============================] - 6s 97us/step - loss: 6.4728e-04 - acc: 0.9901 - val_loss: 5.7232e-04 - val_acc: 0.9899
Epoch 9/20
60000/60000 [==============================] - 6s 97us/step - loss: 5.6721e-04 - acc: 0.9913 - val_loss: 5.6251e-04 - val_acc: 0.9917
Epoch 10/20
60000/60000 [==============================] - 6s 97us/step - loss: 5.0384e-04 - acc: 0.9923 - val_loss: 6.3515e-04 - val_acc: 0.9905
Epoch 11/20
60000/60000 [==============================] - 6s 97us/step - loss: 4.4391e-04 - acc: 0.9931 - val_loss: 7.1615e-04 - val_acc: 0.9881
Epoch 12/20
60000/60000 [==============================] - 6s 97us/step - loss: 3.9281e-04 - acc: 0.9939 - val_loss: 5.5132e-04 - val_acc: 0.9900
Epoch 13/20
60000/60000 [==============================] - 6s 97us/step - loss: 3.8931e-04 - acc: 0.9940 - val_loss: 5.1013e-04 - val_acc: 0.9918
Epoch 14/20
60000/60000 [==============================] - 6s 97us/step - loss: 3.3158e-04 - acc: 0.9951 - val_loss: 4.7458e-04 - val_acc: 0.9927
Epoch 15/20
60000/60000 [==============================] - 6s 102us/step - loss: 3.0994e-04 - acc: 0.9954 - val_loss: 4.4793e-04 - val_acc: 0.9924
Epoch 16/20
60000/60000 [==============================] - 6s 101us/step - loss: 2.8561e-04 - acc: 0.9956 - val_loss: 4.9396e-04 - val_acc: 0.9918
Epoch 17/20
60000/60000 [==============================] - 6s 97us/step - loss: 2.5102e-04 - acc: 0.9960 - val_loss: 4.7689e-04 - val_acc: 0.9927
Epoch 18/20
60000/60000 [==============================] - 6s 98us/step - loss: 2.4750e-04 - acc: 0.9960 - val_loss: 5.3580e-04 - val_acc: 0.9916
Epoch 19/20
60000/60000 [==============================] - 6s 101us/step - loss: 2.1588e-04 - acc: 0.9966 - val_loss: 5.1220e-04 - val_acc: 0.9915
Epoch 20/20
60000/60000 [==============================] - 6s 100us/step - loss: 1.9244e-04 - acc: 0.9971 - val_loss: 4.5228e-04 - val_acc: 0.9921
CNN+Pooling,不考虑置信度的准确率为:0.340964931874
CNN+Pooling,考虑置信度的准确率为:0.0813044449408
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_4 (InputLayer)         (None, None, None, 1)     0         
_________________________________________________________________
conv2d_13 (Conv2D)           (None, None, None, 64)    640       
_________________________________________________________________
conv2d_14 (Conv2D)           (None, None, None, 64)    36928     
_________________________________________________________________
average_pooling2d_4 (Average (None, None, None, 64)    0         
_________________________________________________________________
conv2d_15 (Conv2D)           (None, None, None, 128)   73856     
_________________________________________________________________
conv2d_16 (Conv2D)           (None, None, None, 128)   147584    
_________________________________________________________________
reshape_2 (Reshape)          (None, None, 128)         0         
_________________________________________________________________
capsule_2 (Capsule)          (None, 10, 16)            20480     
_________________________________________________________________
lambda_2 (Lambda)            (None, 10)                0         
=================================================================
Total params: 279,488
Trainable params: 279,488
Non-trainable params: 0
_________________________________________________________________
Train on 60000 samples, validate on 10000 samples
Epoch 1/20
60000/60000 [==============================] - 8s 133us/step - loss: 0.0105 - acc: 0.8710 - val_loss: 0.0025 - val_acc: 0.9770
Epoch 2/20
60000/60000 [==============================] - 7s 120us/step - loss: 0.0021 - acc: 0.9821 - val_loss: 0.0015 - val_acc: 0.9857
Epoch 3/20
60000/60000 [==============================] - 7s 121us/step - loss: 0.0014 - acc: 0.9885 - val_loss: 0.0012 - val_acc: 0.9906
Epoch 4/20
60000/60000 [==============================] - 7s 120us/step - loss: 0.0011 - acc: 0.9907 - val_loss: 9.9176e-04 - val_acc: 0.9907
Epoch 5/20
60000/60000 [==============================] - 7s 120us/step - loss: 8.7051e-04 - acc: 0.9927 - val_loss: 0.0010 - val_acc: 0.9922
Epoch 6/20
60000/60000 [==============================] - 7s 120us/step - loss: 7.5289e-04 - acc: 0.9936 - val_loss: 8.8222e-04 - val_acc: 0.9922
Epoch 7/20
60000/60000 [==============================] - 7s 120us/step - loss: 6.3195e-04 - acc: 0.9948 - val_loss: 8.0888e-04 - val_acc: 0.9916
Epoch 8/20
60000/60000 [==============================] - 7s 119us/step - loss: 5.6204e-04 - acc: 0.9950 - val_loss: 7.2247e-04 - val_acc: 0.9924
Epoch 9/20
60000/60000 [==============================] - 7s 119us/step - loss: 4.7763e-04 - acc: 0.9964 - val_loss: 6.9410e-04 - val_acc: 0.9925
Epoch 10/20
60000/60000 [==============================] - 7s 123us/step - loss: 4.4153e-04 - acc: 0.9967 - val_loss: 7.3162e-04 - val_acc: 0.9926
Epoch 11/20
60000/60000 [==============================] - 7s 121us/step - loss: 3.8947e-04 - acc: 0.9970 - val_loss: 7.1959e-04 - val_acc: 0.9919
Epoch 12/20
60000/60000 [==============================] - 7s 118us/step - loss: 3.3594e-04 - acc: 0.9978 - val_loss: 6.1931e-04 - val_acc: 0.9933
Epoch 13/20
60000/60000 [==============================] - 7s 118us/step - loss: 3.1502e-04 - acc: 0.9978 - val_loss: 7.0153e-04 - val_acc: 0.9922
Epoch 14/20
60000/60000 [==============================] - 7s 119us/step - loss: 2.7268e-04 - acc: 0.9983 - val_loss: 6.1045e-04 - val_acc: 0.9936
Epoch 15/20
60000/60000 [==============================] - 7s 118us/step - loss: 2.3339e-04 - acc: 0.9987 - val_loss: 6.4356e-04 - val_acc: 0.9937
Epoch 16/20
60000/60000 [==============================] - 7s 118us/step - loss: 1.9981e-04 - acc: 0.9989 - val_loss: 7.3893e-04 - val_acc: 0.9926
Epoch 17/20
60000/60000 [==============================] - 7s 118us/step - loss: 1.8829e-04 - acc: 0.9989 - val_loss: 6.3818e-04 - val_acc: 0.9923
Epoch 18/20
60000/60000 [==============================] - 7s 118us/step - loss: 1.6750e-04 - acc: 0.9992 - val_loss: 7.0149e-04 - val_acc: 0.9916
Epoch 19/20
60000/60000 [==============================] - 7s 119us/step - loss: 1.6916e-04 - acc: 0.9992 - val_loss: 7.5558e-04 - val_acc: 0.9921
Epoch 20/20
60000/60000 [==============================] - 7s 118us/step - loss: 1.4234e-04 - acc: 0.9992 - val_loss: 6.7573e-04 - val_acc: 0.9919
CNN+Capsule,不考虑置信度的准确率为:0.96694214876
CNN+Capsule,考虑置信度的准确率为:0.966495421041

4 完整代码见GitHub

https://github.com/leonorand/capsule_network/blob/master/capsule_network_su.ipynb

 

  • 1
    点赞
  • 47
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值