Paper link: paper-Concurrent Spatial and Channel SE in Fully Convolutional Networks
Code download: github-Concurrent Spatial and Channel SE in Fully Convolutional Networks
Preface
This work is mainly inspired by the squeeze & excitation (SE) block. Given an input feature map, the SE block squeezes away the spatial dependencies with global average pooling and learns a weight for each channel, producing a tensor that encodes the importance of every channel. The block is highly adaptable: it can be inserted into almost any convolutional network, and it helped win the ILSVRC 2017 classification competition. Inspired by this, the authors propose two variants, sSE and scSE, and refer to the original SE block as cSE. Together, the three blocks model the channel importance (cSE), the spatial importance (sSE), and the combination of both (scSE) of a feature map.
I. cSE: Spatial Squeeze and Channel Excitation Block
The authors refer to the block from the original SE paper as cSE.
How it works:
1. Apply global average pooling to the feature map, reducing it from [C, H, W] to [C, 1, 1].
2. Pass the result through a C/2×1×1 convolution and then a C×1×1 convolution (equivalently, two fully connected layers), producing a C-dimensional vector.
3. Normalize the vector with a sigmoid to obtain the channel attention mask.
4. Multiply the feature map by the mask channel-wise to obtain the recalibrated feature map.
Implementation

import keras.backend as K
from keras.layers import *
from keras.models import Model

# spatial squeeze by global average pooling, then channel excitation
def cse_block(prevlayer, prefix):
    # 1. global average pooling over H and W; note the code is channels-last,
    #    so [H, W, C] is reduced to a C-dimensional vector
    mean = Lambda(lambda xin: K.mean(xin, axis=[1, 2]))(prevlayer)
    # 2. two fully connected layers, equivalent to the C/2x1x1 and Cx1x1
    #    convolutions once the spatial dimensions have been pooled away
    #    (K.int_shape returns the static shape as a tuple of ints or None)
    lin1 = Dense(K.int_shape(prevlayer)[3] // 2,
                 name=prefix + 'cse_lin1', activation='relu')(mean)
    # 3. the sigmoid yields the channel attention mask
    lin2 = Dense(K.int_shape(prevlayer)[3],
                 name=prefix + 'cse_lin2', activation='sigmoid')(lin1)
    # 4. channel-wise multiplication; the C weights broadcast over H and W
    x = Multiply()([prevlayer, lin2])
    return x
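As a quick sanity check (a minimal sketch; the 64×64×32 input is an arbitrary toy shape, not from the paper), the block should leave the feature-map shape unchanged:

inp = Input(shape=(64, 64, 32))
out = cse_block(inp, prefix='demo_')
print(Model(inp, out).output_shape)  # expected: (None, 64, 64, 32)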
II. sSE: Channel Squeeze and Spatial Excitation Block
How it works:
1. Apply a 1×1 convolution with a single output channel directly to the feature map, collapsing it from [C, H, W] to [1, H, W].
2. Pass the result through a sigmoid to obtain the spatial attention map.
3. Multiply the original feature map by this map element-wise.
Implementation

import keras.backend as K
from keras.layers import *
from keras.models import Model

# channel squeeze by 1x1 convolution, then spatial excitation
def sse_block(prevlayer, prefix):
    # a 1x1 convolution with a single output channel collapses [H, W, C] into
    # a [H, W, 1] map; the sigmoid turns it into a spatial attention map
    # (the original snippet passed K.int_shape(prevlayer)[3] output channels
    #  and flagged it with "Bug? Should be 1 here?" -- per the paper, it
    #  should indeed be 1)
    conv = Conv2D(1, (1, 1), padding="same", kernel_initializer="he_normal",
                  activation='sigmoid', strides=(1, 1),
                  name=prefix + "_conv")(prevlayer)
    # element-wise multiplication broadcasts the map across all C channels
    conv = Multiply(name=prefix + "_mul")([prevlayer, conv])
    return conv
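To verify the channel squeeze (again a minimal sketch with an arbitrary toy shape), one can inspect the intermediate attention map, which should have exactly one channel:

inp = Input(shape=(64, 64, 32))
m = Model(inp, sse_block(inp, prefix='demo'))
print(m.get_layer('demo_conv').output_shape)  # (None, 64, 64, 1)
print(m.output_shape)                         # (None, 64, 64, 32)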
III. scSE: Spatial and Channel Squeeze & Excitation Block
How it works:
scSE simply adds the outputs of the two blocks above element-wise, so every location is recalibrated by both its channel weight and its spatial weight.
Implementation
def csse_block(x, prefix):
    '''
    Implementation of Concurrent Spatial and Channel ‘Squeeze & Excitation’
    in Fully Convolutional Networks, https://arxiv.org/abs/1803.02579
    Usage: x = csse_block(x, prefix='csse_block_{}'.format(i))
    '''
    cse = cse_block(x, prefix)
    sse = sse_block(x, prefix)
    # element-wise addition of the two recalibrated feature maps
    # (renamed from "_csse_mul": the layer is an Add, not a Multiply)
    x = Add(name=prefix + "_csse_add")([cse, sse])
    return x
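The overhead is tiny and easy to count. For C = 32 channels (a toy shape, for illustration), the cSE branch costs 32·16+16 + 16·32+32 = 1,072 parameters and the sSE branch 32+1 = 33, so the whole scSE block adds only 1,105 parameters:

inp = Input(shape=(64, 64, 32))
m = Model(inp, csse_block(inp, prefix='demo'))
print(m.count_params())  # 1105 = 1072 (cSE) + 33 (sSE)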
scSEUNet
Finally, here is the full code that integrates scSE into a U-Net (the scSE blocks are attached to the last three decoder stages):
import keras.backend as K
from keras.layers import *
from keras.models import Model

# spatial squeeze by global average pooling, then channel excitation
def cse_block(prevlayer, prefix):
    # global average pooling over H and W (channels-last: [H, W, C] -> [C])
    mean = Lambda(lambda xin: K.mean(xin, axis=[1, 2]))(prevlayer)
    # two fully connected layers: C -> C/2 with ReLU, then C/2 -> C with sigmoid
    lin1 = Dense(K.int_shape(prevlayer)[3] // 2,
                 name=prefix + 'cse_lin1', activation='relu')(mean)
    lin2 = Dense(K.int_shape(prevlayer)[3],
                 name=prefix + 'cse_lin2', activation='sigmoid')(lin1)
    # channel-wise multiplication
    x = Multiply()([prevlayer, lin2])
    return x

# channel squeeze by 1x1 convolution, then spatial excitation
def sse_block(prevlayer, prefix):
    # a single-channel 1x1 convolution produces the spatial attention map
    # (fixed from the original snippet, which used C output channels)
    conv = Conv2D(1, (1, 1), padding="same", kernel_initializer="he_normal",
                  activation='sigmoid', strides=(1, 1),
                  name=prefix + "_conv")(prevlayer)
    conv = Multiply(name=prefix + "_mul")([prevlayer, conv])
    return conv

# concurrent spatial and channel squeeze & excitation
def csse_block(x, prefix):
    '''
    Implementation of Concurrent Spatial and Channel ‘Squeeze & Excitation’
    in Fully Convolutional Networks, https://arxiv.org/abs/1803.02579
    Usage: x = csse_block(x, prefix='csse_block_{}'.format(i))
    '''
    cse = cse_block(x, prefix)
    sse = sse_block(x, prefix)
    # element-wise addition of the two recalibrated feature maps
    x = Add(name=prefix + "_csse_add")([cse, sse])
    return x
def scSEUnet(nclasses, input_height=224, input_width=224):
    inputs = Input(shape=(input_height, input_width, 3))

    # encoder: four down-sampling stages
    conv1 = Conv2D(16, 3, activation='relu', padding='same', kernel_initializer='he_normal')(inputs)
    conv1 = BatchNormalization()(conv1)
    conv1 = Conv2D(16, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv1)
    conv1 = BatchNormalization()(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = Conv2D(32, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool1)
    conv2 = BatchNormalization()(conv2)
    conv2 = Conv2D(32, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv2)
    conv2 = BatchNormalization()(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool2)
    conv3 = BatchNormalization()(conv3)
    conv3 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv3)
    conv3 = BatchNormalization()(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    conv4 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool3)
    conv4 = BatchNormalization()(conv4)
    conv4 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv4)
    conv4 = BatchNormalization()(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

    # bottleneck
    conv5 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool4)
    conv5 = BatchNormalization()(conv5)
    conv5 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv5)
    conv5 = BatchNormalization()(conv5)

    # decoder: scSE blocks recalibrate the last three stages
    up6 = UpSampling2D(size=(2, 2))(conv5)
    up6 = Conv2D(128, 2, activation='relu', padding='same', kernel_initializer='he_normal')(up6)
    up6 = BatchNormalization()(up6)
    merge6 = concatenate([conv4, up6], axis=3)
    conv6 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge6)
    conv6 = BatchNormalization()(conv6)
    conv6 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv6)
    conv6 = BatchNormalization()(conv6)

    up7 = UpSampling2D(size=(2, 2))(conv6)
    up7 = Conv2D(64, 2, activation='relu', padding='same', kernel_initializer='he_normal')(up7)
    up7 = BatchNormalization()(up7)
    merge7 = concatenate([conv3, up7], axis=3)
    conv7 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge7)
    conv7 = BatchNormalization()(conv7)
    conv7 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv7)
    conv7 = BatchNormalization()(conv7)
    conv7 = csse_block(conv7, prefix="conv7")

    up8 = UpSampling2D(size=(2, 2))(conv7)
    up8 = Conv2D(32, 2, activation='relu', padding='same', kernel_initializer='he_normal')(up8)
    up8 = BatchNormalization()(up8)
    merge8 = concatenate([conv2, up8], axis=3)
    conv8 = Conv2D(32, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge8)
    conv8 = BatchNormalization()(conv8)
    conv8 = Conv2D(32, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv8)
    conv8 = BatchNormalization()(conv8)
    conv8 = csse_block(conv8, prefix="conv8")

    up9 = UpSampling2D(size=(2, 2))(conv8)
    up9 = Conv2D(16, 2, activation='relu', padding='same', kernel_initializer='he_normal')(up9)
    up9 = BatchNormalization()(up9)
    merge9 = concatenate([conv1, up9], axis=3)
    conv9 = Conv2D(16, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge9)
    conv9 = BatchNormalization()(conv9)
    conv9 = Conv2D(16, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv9)
    conv9 = BatchNormalization()(conv9)
    conv9 = csse_block(conv9, prefix="conv9")

    # classification head
    conv10 = Conv2D(nclasses, (3, 3), padding='same')(conv9)
    o = BatchNormalization()(conv10)
    out = Activation('sigmoid')(o)

    # note: keyword arguments fixed from input=/output= to inputs=/outputs=
    model = Model(inputs=inputs, outputs=out)
    return model
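Building and compiling the model then looks like this (a minimal sketch; the optimizer and loss are illustrative assumptions, with binary_crossentropy chosen to match the sigmoid output when nclasses=1):

if __name__ == '__main__':
    model = scSEUnet(nclasses=1, input_height=224, input_width=224)
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    model.summary()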