# TF2 implementation of the lightweight semantic-segmentation network Mobile-SegNet

"""
Created on 2021/3/15 9:58.
@Author: haifei
"""
# https://zhuanlan.zhihu.com/p/136657292
# https://blog.csdn.net/weixin_44791964/article/details/102979289


from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Activation
from tensorflow.keras.layers import UpSampling2D, ZeroPadding2D, DepthwiseConv2D
from tensorflow.keras import Model


def _conv_block(inputs, filters, alpha, kernel=(3, 3), strides=(1, 1)):
    """Standard convolution unit: zero-pad -> Conv2D (valid) -> BN -> ReLU.

    The output channel count is ``filters`` scaled by the MobileNet width
    multiplier ``alpha``. Explicit (1, 1) zero padding plus a 'valid'
    convolution reproduces 'same'-style spatial behavior for a 3x3 kernel
    (darknet/YOLOv3-style conv + bn + relu stack).
    """
    scaled_filters = int(filters * alpha)

    padded = ZeroPadding2D(padding=(1, 1))(inputs)
    conv_out = Conv2D(scaled_filters, kernel, padding='valid', strides=strides)(padded)
    normalized = BatchNormalization()(conv_out)
    return Activation("relu")(normalized)

"""
https://blog.csdn.net/m0_37617773/article/details/105988668
深度可分离卷积 SeparableConv2D与DepthwiseConv2D的区别
    简单来说,SeparableConv2D是DepthwiseConv2D的升级版。通常来说深度可分离卷积分为两步,
    也就是在depplabv3+中,经常使用的方法。
        第一步:depthwise convolution是在每个通道上独自的进行空间卷积
        第二步:pointwise convolution是利用1x1卷积核组合前面depthwise convolution得到的特征
    而DepthwiseConv2D只实现了第一步, SeparableConv2D直接实现了两步。
"""

def _depthwise_conv_block(inputs, pointwise_conv_filters, alpha, depth_multiplier=1, strides=(1, 1)):
    """Depthwise-separable convolution block (MobileNetV1 style).

    Step 1: a 3x3 DepthwiseConv2D filters each channel independently
            (after explicit (1, 1) zero padding, so 'valid' acts like 'same').
    Step 2: a 1x1 pointwise Conv2D mixes channels into
            ``pointwise_conv_filters * alpha`` output channels.
    Each convolution is followed by BatchNorm + ReLU.
    """
    out_channels = int(pointwise_conv_filters * alpha)

    # depthwise stage: per-channel spatial filtering
    dw = ZeroPadding2D((1, 1))(inputs)
    dw = DepthwiseConv2D((3, 3), padding='valid', depth_multiplier=depth_multiplier, strides=strides)(dw)
    dw = BatchNormalization()(dw)
    dw = Activation("relu")(dw)

    # pointwise stage: 1x1 channel mixing
    pw = Conv2D(out_channels, (1, 1), padding='same', strides=(1, 1))(dw)
    pw = BatchNormalization()(pw)
    return Activation("relu")(pw)


def get_mobilenet_encoder(inputs):  # MobileNetV1 (a.k.a. MobileNets) backbone
    """Build the MobileNetV1 feature extractor and collect skip features.

    Returns a list [f1, f2, f3, f4, f5] of feature maps taken after each
    resolution stage; every stage begins with a stride-2 block that halves
    the spatial size.
    """
    alpha = 1.0
    depth_multiplier = 1
    features = []

    # stage 1: initial strided conv + one depthwise block
    x = _conv_block(inputs, 32, alpha, strides=(2, 2))  # downsample
    x = _depthwise_conv_block(x, 64, alpha, depth_multiplier)
    features.append(x)  # f1

    # stage 2
    x = _depthwise_conv_block(x, 128, alpha, depth_multiplier, strides=(2, 2))  # downsample
    x = _depthwise_conv_block(x, 128, alpha, depth_multiplier)
    features.append(x)  # f2

    # stage 3
    x = _depthwise_conv_block(x, 256, alpha, depth_multiplier, strides=(2, 2))  # downsample
    x = _depthwise_conv_block(x, 256, alpha, depth_multiplier)
    features.append(x)  # f3

    # stage 4: one strided block followed by five refinement blocks
    x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, strides=(2, 2))  # downsample
    for _ in range(5):
        x = _depthwise_conv_block(x, 512, alpha, depth_multiplier)
    features.append(x)  # f4

    # stage 5
    x = _depthwise_conv_block(x, 1024, alpha, depth_multiplier, strides=(2, 2))
    x = _depthwise_conv_block(x, 1024, alpha, depth_multiplier)
    features.append(x)  # f5

    return features


'''
def get_segnet_decoder(feature):
    #
    x = UpSampling2D(size=(2, 2))(feature)
    x = Conv2D(512, (3, 3), strides=1, padding='same', activation='relu')(x)
    x = BatchNormalization()(x)
    #
    x = UpSampling2D(size=(2, 2))(x)
    x = Conv2D(512, (3, 3), strides=1, padding='same', activation='relu')(x)
    x = BatchNormalization()(x)
    #
    x = UpSampling2D(size=(2, 2))(x)
    x = Conv2D(256, (3, 3), strides=1, padding='same', activation='relu')(x)
    x = BatchNormalization()(x)
    #
    x = UpSampling2D(size=(2, 2))(x)
    x = Conv2D(128, (3, 3), strides=1, padding='same', activation='relu')(x)
    x = BatchNormalization()(x)
    #
    x = UpSampling2D(size=(2, 2))(x)
    x = Conv2D(64, (3, 3), strides=1, padding='same', activation='relu')(x)
    x = BatchNormalization()(x)
    #
    return x
'''


def get_segnet_decoder(feature):
    """SegNet-style decoder head.

    First refines the input feature map with a 512-filter conv, then applies
    four stages of 2x upsampling + conv + BN, shrinking channels along the
    schedule 256 -> 128 -> 128 -> 64 (overall 16x spatial upscaling).
    """
    x = Conv2D(512, (3, 3), strides=1, padding='same', activation='relu')(feature)
    x = BatchNormalization()(x)

    for num_filters in (256, 128, 128, 64):
        x = UpSampling2D(size=(2, 2))(x)
        x = Conv2D(num_filters, (3, 3), strides=1, padding='same', activation='relu')(x)
        x = BatchNormalization()(x)

    return x


def build_model(tif_size, bands, class_num):
    """Assemble the Mobile-SegNet segmentation model.

    Args:
        tif_size: input tile height/width in pixels (square input).
        bands: number of input channels (e.g. 3 for RGB).
        class_num: number of segmentation classes (softmax output channels).

    Returns:
        A ``tf.keras.Model`` mapping (tif_size, tif_size, bands) images to
        per-pixel class probabilities.
    """
    from pathlib import Path
    print('===== %s =====' % Path(__file__).name)
    # Fixed: the original used sys._getframe().f_code.co_name, a
    # CPython-private API; the function's own __name__ gives the same string
    # portably.
    print('===== %s =====' % build_model.__name__)

    # input tensor
    inputs = Input(shape=(tif_size, tif_size, bands))
    # encoder: MobileNetV1 backbone; levels[3] (f4) is at 1/16 resolution
    levels = get_mobilenet_encoder(inputs)
    # decoder: four 2x upsamplings restore full input resolution from f4
    x = get_segnet_decoder(feature=levels[3])
    # per-pixel classification head
    x = Conv2D(class_num, (1, 1), strides=1, padding='same', activation='softmax')(x)

    mymodel = Model(inputs, x)
    return mymodel


# model = build_model(256, 3, 2)
# model.summary()
# (removed: CSDN blog-page footer text accidentally pasted into the source file)