XceptionNet学习笔记
惯例先放大佬链接(https://mp.weixin.qq.com/s?__biz=MzA4MjY4NTk0NQ==&mid=2247485649&idx=1&sn=915dfceaad52ddd95eebc48b14b689da&chksm=9f80b247a8f73b51c9aae836adfd7838b95bb60dd340e50c62c74a3449757affa5a029a7cc17&scene=21#wechat_redirect)
XceptionNet其实就是在Inception V3之后,Google提出了XceptionNet,这是对Inception V3的一种改进,主要使用了深度可分离卷积来替换掉Inception V3中的卷积操作。
为了更好的说明Xception网络,我们首先需要从Inception V3来回顾一下。下面的Figure1展示了Inception V3的结构图。可以看到Inception的核心思想是将多种特征提取方式如1x1卷积,3x3卷积,5x5卷积,pooling等产生的特征图进行了concate,达到融合多种特征的效果。
然后,从Inception V3的结构联想到了一个简化的Inception结构,如Figure2所示。
再然后将Figure2的结构进行改进,就获得了Figure3所示的结构。每个3x3的卷积取前面1x1卷积后的1/3通道。
同时Figure4则为我们展示了将这一想法应用到极致,即每个通道接一个3x3卷积的结构。
所以Xception认为在卷积操作中,channel和spatial的相关性是可以解耦的,也就是不用一起运算,而普通的卷积操作中将两者混合起来进行运算,所以通过解耦这两部分可以精确控制channel和spatial上的运算,提神性能。
看到这里大家基本就知道XceptionNet的网络结构了。接下来我们比较一下这种方式和深度可分离卷积的区别。下面是深度可分离卷积的过程,可以看得出来两者的区别就是3x3卷积和1x1卷积的使用顺序的前后不一样而已,事实上作者也提到这种先后顺序并不会对最后的结果有很大的影响。
在Figure4展示的「极致的 Inception”模块」中,用于学习空间相关性的3x3卷积和用于学习通道相关性的1x1卷积「之间」如果不使用激活函数,收敛过程会更快,并且结果更好。
放上网络结构图:
代码(keras):
from keras.models import Model
from keras import layers
from keras.layers import Dense, Input, BatchNormalization, Activation
from keras.layers import Conv2D, SeparableConv2D, MaxPooling2D, GlobalAveragePooling2D, GlobalMaxPooling2D
from keras.applications.imagenet_utils import _obtain_input_shape
from keras.utils.data_utils import get_file
WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.4/xception_weights_tf_dim_ordering_tf_kernels.h5'
def Xception():
# Determine proper input shape
input_shape = _obtain_input_shape(None, default_size=299, min_size=71, data_format='channels_last', include_top=False)
img_input = Input(shape=input_shape)
# Block 1
x = Conv2D(32, (3, 3), strides=(2, 2), use_bias=False)(img_input)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = Conv2D(64, (3, 3), use_bias=False)(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
residual = Conv2D(128, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
residual = BatchNormalization()(residual)
# Block 2
x = SeparableConv2D(128, (3, 3), padding='same', use_bias=False)(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = SeparableConv2D(128, (3, 3), padding='same', use_bias=False)(x)
x = BatchNormalization()(x)
# Block 2 Pool
x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
x = layers.add([x, residual])
residual = Conv2D(256, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
residual = BatchNormalization()(residual)
# Block 3
x = Activation('relu')(x)
x = SeparableConv2D(256, (3, 3), padding='same', use_bias=False)(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = SeparableConv2D(256, (3, 3), padding='same', use_bias=False)(x)
x = BatchNormalization()(x)
# Block 3 Pool
x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
x = layers.add([x, residual])
residual = Conv2D(728, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
residual = BatchNormalization()(residual)
# Block 4
x = Activation('relu')(x)
x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False)(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False)(x)
x = BatchNormalization()(x)
x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
x = layers.add([x, residual])
# Block 5 - 12
for i in range(8):
residual = x
x = Activation('relu')(x)
x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False)(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False)(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False)(x)
x = BatchNormalization()(x)
x = layers.add([x, residual])
residual = Conv2D(1024, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
residual = BatchNormalization()(residual)
# Block 13
x = Activation('relu')(x)
x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False)(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = SeparableConv2D(1024, (3, 3), padding='same', use_bias=False)(x)
x = BatchNormalization()(x)
# Block 13 Pool
x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
x = layers.add([x, residual])
# Block 14
x = SeparableConv2D(1536, (3, 3), padding='same', use_bias=False)(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
# Block 14 part 2
x = SeparableConv2D(2048, (3, 3), padding='same', use_bias=False)(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
# Fully Connected Layer
x = GlobalAveragePooling2D()(x)
x = Dense(1000, activation='softmax')(x)
inputs = img_input
# Create model
model = Model(inputs, x, name='xception')
# Download and cache the Xception weights file
weights_path = get_file('xception_weights.h5', WEIGHTS_PATH, cache_subdir='models')
# load weights
model.load_weights(weights_path)
return model