手写代码复刻MobileNet网络(完整代码)

from keras.layers import Dense,Convolution2D,Input,BatchNormalization,Activation,GlobalAveragePooling2D
from keras.models import Model
from keras.applications.mobilenet import DepthwiseConv2D


def _depthwise_separable_block(x, pointwise_filters, strides=(1, 1)):
    """Apply one MobileNet depthwise-separable unit to the Keras tensor ``x``.

    The unit is a 3x3 depthwise convolution followed by a 1x1 pointwise
    convolution; each convolution is followed by BatchNormalization and ReLU.

    Args:
        x: Keras tensor to transform.
        pointwise_filters: Number of output channels of the 1x1 pointwise
            convolution (already scaled by the width multiplier).
        strides: Stride of the depthwise convolution; ``(2, 2)`` halves the
            spatial resolution.

    Returns:
        The output Keras tensor.
    """
    # DepthwiseConv2D takes kernel_size as its FIRST argument and has no
    # `filters` argument: its output channel count is the input channel
    # count times depth_multiplier (default 1).
    x = DepthwiseConv2D((3, 3), strides=strides, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Convolution2D(pointwise_filters, (1, 1), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    return x


def trash_model(input_shape, num_classes, alpha=1):
    """Build a MobileNet (v1) image classifier.

    Args:
        input_shape: Shape of a single input image, passed to
            ``Input(shape=...)``.
        num_classes: Number of units of the final softmax layer.
        alpha: Width multiplier. Every convolution's channel count is
            multiplied by this factor and truncated with ``int``.

    Returns:
        An uncompiled ``Model`` mapping images to class probabilities.
    """
    def scaled(channels):
        # Apply the width multiplier exactly as the original code did.
        return int(channels * alpha)

    inputs = Input(shape=input_shape)

    # Stem: a standard 3x3 convolution with stride 2.
    x = Convolution2D(scaled(32), (3, 3), strides=(2, 2), padding='same', use_bias=False)(inputs)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    # (pointwise output channels, depthwise stride) for each of the 13
    # depthwise-separable units of MobileNet v1.
    block_specs = (
        [
            (64, (1, 1)),
            (128, (2, 2)),
            (128, (1, 1)),
            (256, (2, 2)),
            (256, (1, 1)),
            (512, (2, 2)),
        ]
        + [(512, (1, 1))] * 5
        + [
            (1024, (2, 2)),
            # The last unit keeps the 7x7 resolution (input and output are
            # both 7x7x1024 in the paper's Table 1), so its stride is 1, as in
            # keras.applications.mobilenet.
            (1024, (1, 1)),
        ]
    )
    for channels, strides in block_specs:
        x = _depthwise_separable_block(x, scaled(channels), strides)

    x = GlobalAveragePooling2D()(x)
    # NOTE: this extra hidden fully connected layer is kept from the original
    # version of this function; the paper goes straight from average pooling
    # to the classifier.
    x = Dense(1000, activation='relu')(x)
    outputs = Dense(num_classes, activation='softmax')(x)
    return Model(inputs, outputs)

上面是小主写的源代码

运行代码无报错:

Depthwise单元:

网络结构:

原论文网址:[1704.04861] MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications (https://arxiv.org/abs/1704.04861)

使用框架:keras 2.1.3

                  tensorflow cpu 2.4.1

本文参考了同校师兄的代码,师兄新出论文:Branch Feature Fusion Convolution Network for Remote Sensing Scene Classification,欢迎支持。地址如下。

https://ieeexplore.ieee.org/document/9172096

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值