from keras.layers import Dense,Convolution2D,Input,BatchNormalization,Activation,GlobalAveragePooling2D
from keras.models import Model
from keras.applications.mobilenet import DepthwiseConv2D
def _conv_bn_relu(x, filters, kernel_size=(1, 1), strides=(1, 1)):
    """Apply a bias-free 'same'-padded Conv2D, then BatchNormalization and ReLU.

    Args:
        x: Keras tensor to transform.
        filters: Number of output channels of the convolution.
        kernel_size: Convolution kernel size; defaults to a 1x1 (pointwise) conv.
        strides: Convolution strides.

    Returns:
        The activated output tensor.
    """
    x = Convolution2D(filters, kernel_size, strides=strides,
                      padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    return Activation('relu')(x)


def _depthwise_separable_block(x, pointwise_filters, strides=(1, 1)):
    """Apply a 3x3 depthwise conv and a 1x1 pointwise conv, each with BN + ReLU.

    Args:
        x: Keras tensor to transform.
        pointwise_filters: Output channels of the 1x1 pointwise convolution.
        strides: Strides of the depthwise convolution only; the pointwise
            convolution always uses stride 1.

    Returns:
        The activated output tensor of the pointwise convolution.
    """
    # DepthwiseConv2D takes kernel_size as its first positional argument and
    # has no filter-count argument: its output channel count follows from the
    # input channels (depth_multiplier is left at the layer's default).
    # Passing a filter count first, as the previous version did, turned it
    # into the kernel size and pushed (3, 3) into `strides`.
    x = DepthwiseConv2D((3, 3), strides=strides,
                        padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    return _conv_bn_relu(x, pointwise_filters)


def trash_model(input_shape, num_classes, alpha=1):
    """Build a MobileNet-v1-style image classifier.

    Args:
        input_shape: Shape passed to ``Input(shape=...)``.
        num_classes: Number of units in the final softmax layer.
        alpha: Width multiplier; every convolution's filter count is scaled by
            it and truncated with ``int``.

    Returns:
        An uncompiled ``keras.models.Model`` mapping the input to softmax
        class probabilities.
    """
    # (unscaled pointwise filters, depthwise strides) for each
    # depthwise-separable block, in network order.
    block_specs = (
        [(64, (1, 1)),
         (128, (2, 2)),
         (128, (1, 1)),
         (256, (2, 2)),
         (256, (1, 1)),
         (512, (2, 2))]
        + [(512, (1, 1))] * 5
        + [(1024, (2, 2)),
           # NOTE(review): the stock Keras MobileNet appears to use stride 1
           # for this last depthwise conv; stride 2 is kept from the original
           # code -- confirm which is intended.
           (1024, (2, 2))]
    )

    inputs = Input(shape=input_shape)
    # Stem: a full (non-separable) 3x3 convolution with stride 2.
    x = _conv_bn_relu(inputs, int(32 * alpha), kernel_size=(3, 3),
                      strides=(2, 2))
    for filters, strides in block_specs:
        x = _depthwise_separable_block(x, int(filters * alpha), strides=strides)

    x = GlobalAveragePooling2D()(x)
    # Extra 1000-unit hidden layer kept from the original design.
    x = Dense(1000, activation='relu')(x)
    outputs = Dense(num_classes, activation='softmax')(x)
    return Model(inputs, outputs)
以上是博主编写的源代码。
运行代码无报错:
Depthwise单元:
网络结构:
使用框架:Keras 2.1.3,TensorFlow(CPU 版)2.4.1。
本文参考了同校师兄的代码。师兄新发表了论文《Branch Feature Fusion Convolution Network for Remote Sensing Scene Classification》,欢迎支持,地址如下: