一、keras实现的Senet block模块代码
import keras
class SeBlock(keras.layers.Layer):
def __init__(self, reduction=4,**kwargs):
super(SeBlock,self).__init__(**kwargs)
self.reduction = reduction
def build(self,input_shape):#构建layer时需要实现
#input_shape
pass
def call(self, inputs):
x = keras.layers.GlobalAveragePooling2D()(inputs)
x = keras.layers.Dense(int(x.shape[-1]) // self.reduction, use_bias=False,activation=keras.activations.relu)(x)
x = keras.layers.Dense(int(inputs.shape[-1]), use_bias=False,activation=keras.activations.hard_sigmoid)(x)
return keras.layers.Multiply()([inputs,x]) #给通道加权重
#return inputs*x
二、Senet block模块调用
outputs=SeBlock()(inputs) #创建一个SeBlock匿名对象,使用对象()调用call方法