Keras DenseNet Design

  • A design targeted at fixed-length text images
from keras.models import Model
from keras.layers import (Input, Dense, Dropout, Activation, Reshape, Permute,
                          Conv2D, Conv2DTranspose, ZeroPadding2D,
                          AveragePooling2D, GlobalAveragePooling2D, Flatten,
                          BatchNormalization, TimeDistributed, concatenate)
from keras.regularizers import l2


def conv_block(input, growth_rate, dropout_rate=None, weight_decay=1e-4):
    # BN -> ReLU -> 3x3 conv producing growth_rate channels, optional dropout
    # (note: weight_decay is accepted but not applied in this block)
    x = BatchNormalization(axis=-1, epsilon=1.1e-5)(input)
    x = Activation('relu')(x)
    x = Conv2D(growth_rate, (3, 3), kernel_initializer='he_normal', padding='same')(x)
    if dropout_rate:
        x = Dropout(dropout_rate)(x)
    return x

def dense_block(x, nb_layers, nb_filter, growth_rate, dropout_rate=0.2, weight_decay=1e-4):
    # Each conv_block adds growth_rate channels via channel-wise concatenation
    for i in range(nb_layers):
        cb = conv_block(x, growth_rate, dropout_rate, weight_decay)
        x = concatenate([x, cb], axis=-1)
        nb_filter += growth_rate
    return x, nb_filter

def transition_block(input, nb_filter, dropout_rate=None, pooltype=1, weight_decay=1e-4):
    # BN -> ReLU -> 1x1 conv to reset the channel count, optional dropout, then pooling
    x = BatchNormalization(axis=-1, epsilon=1.1e-5)(input)
    x = Activation('relu')(x)
    x = Conv2D(nb_filter, (1, 1), kernel_initializer='he_normal', padding='same', use_bias=False,
               kernel_regularizer=l2(weight_decay))(x)

    if dropout_rate:
        x = Dropout(dropout_rate)(x)

    if pooltype == 2:
        # halve both height and width
        x = AveragePooling2D((2, 2), strides=(2, 2))(x)
    elif pooltype == 1:
        # pad the width, then halve the height while nearly preserving the width
        x = ZeroPadding2D(padding=(0, 1))(x)
        x = AveragePooling2D((2, 2), strides=(2, 1))(x)
    elif pooltype == 3:
        # halve the height only (width shrinks by 1)
        x = AveragePooling2D((2, 2), strides=(2, 1))(x)
    return x, nb_filter

def dense_cnn(input, nclass):

    _dropout_rate = 0.2 
    _weight_decay = 1e-4

    _nb_filter = 64
    # conv 64 5*5 s=2
    x = Conv2D(_nb_filter, (5, 5), strides=(2, 2), kernel_initializer='he_normal', padding='same',
               use_bias=False, kernel_regularizer=l2(_weight_decay))(input)
   
    # 64 + 8 * 8 = 128
    x, _nb_filter = dense_block(x, 8, _nb_filter, 8, None, _weight_decay)
    # 128
    x, _nb_filter = transition_block(x, 128, _dropout_rate, 2, _weight_decay)

    # 128 + 8 * 8 = 192
    x, _nb_filter = dense_block(x, 8, _nb_filter, 8, None, _weight_decay)
    # 192 -> 128
    x, _nb_filter = transition_block(x, 128, _dropout_rate, 2, _weight_decay)

    # 128 + 8 * 8 = 192
    x, _nb_filter = dense_block(x, 8, _nb_filter, 8, None, _weight_decay)

    x = BatchNormalization(axis=-1, epsilon=1.1e-5)(x)
    x = Activation('relu')(x)
    x = Permute((2, 1, 3), name='permute')(x)
    x = TimeDistributed(Flatten(), name='flatten')(x)
    y_pred = Dense(nclass, name='out', activation='softmax')(x)

    # basemodel = Model(inputs=input, outputs=y_pred)
    # basemodel.summary()

    return y_pred

def dense_blstm(input):
    # placeholder for a DenseNet + BLSTM variant (not implemented)
    pass

input = Input(shape=(32, 280, 1), name='the_input')
y_pred = dense_cnn(input, 7000)
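
The commented-out basemodel lines show the intended next step; here is a minimal usage sketch (wrapping the graph in a Model and printing the summary is my assumption about typical use; in an OCR pipeline the (None, 35, 7000) output would normally feed a CTC loss, which is not shown here):

# Wrap the graph in a Model, as the commented-out basemodel lines suggest.
basemodel = Model(inputs=input, outputs=y_pred)
basemodel.summary()  # the last layer 'out' should report (None, 35, 7000)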

Notes

  • x = Permute((2, 1, 3), name='permute')(x): the tensor shape is (None, 4, 35, 192) before this op and (None, 35, 4, 192) after it, where None is the batch size (number of samples), 192 is the number of channels, and the middle 4 and 35 are the image height and width.
  • x = TimeDistributed(Flatten(), name='flatten')(x): takes (None, 35, 4, 192) as input and produces (None, 35, 768), where 768 = 192 * 4.
  • y_pred = Dense(nclass, name='out', activation='softmax')(x): the output shape is (None, 35, 7000). These shapes are checked in the sketch below.
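
These shapes can be verified directly against the basemodel built in the usage sketch above (a minimal probe, assuming that sketch has been run):

# Probe the head layers to confirm the shapes quoted in the notes.
print(basemodel.get_layer('permute').output_shape)  # (None, 35, 4, 192)
print(basemodel.get_layer('flatten').output_shape)  # (None, 35, 768)
print(basemodel.get_layer('out').output_shape)      # (None, 35, 7000)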

Network structure

  • The first stage is a convolution with 64 output channels, a 5*5 kernel, and stride 2 in both the x and y directions; it is followed by the second stage, a dense block.

  • The second stage is a dense block with 8 sub-layers. Each sub-layer performs BN->Activation->Conv2D and concatenates the result along the channel axis; each conv_block outputs growth_rate channels (the Dropout step is actually skipped here, since dense_cnn passes dropout_rate=None), so _nb_filter = nb_filter + growth_rate * 8 = 64 + 8 * 8 = 128. The other parameters are unchanged.

  • The third stage is a transition block with four operations, BN->Activation->Conv2D->AveragePooling2D; the 1x1 conv output channel count is set to _nb_filter = 128.

  • The fourth stage is a dense block: _nb_filter = nb_filter + growth_rate * 8 = 128 + 8 * 8 = 192. Each sub-layer again performs BN->Activation->Conv2D plus channel-wise concatenation, so the output has 192 channels.

  • The fifth stage is another transition block (BN->Activation->Conv2D->AveragePooling2D); the conv output channel count is again set to _nb_filter = 128.

  • The sixth stage is a dense block; each sub-layer again performs BN->Activation->Conv2D plus channel-wise concatenation, with _nb_filter = nb_filter + growth_rate * 8 = 128 + 8 * 8 = 192, so the output has 192 channels.

  • The seventh stage is the sequence head, where an RNN would sit in a CRNN (dense_blstm is left unimplemented): BN->Activation->Permute->TimeDistributed->Dense. The resulting spatial sizes are traced in the sketch after this list.
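
Tracing those stages end to end shows where the 4 and 35 in the notes come from; a small arithmetic sketch (plain Python, no Keras required):

import math

# Downsampling trace for a 32x280 input: the 'same'-padded stride-2 conv
# rounds up, and each 'valid' 2x2 average pool with stride (2, 2) halves.
h, w = 32, 280
h, w = math.ceil(h / 2), math.ceil(w / 2)  # stage 1 conv, stride (2, 2): 16 x 140
h, w = h // 2, w // 2                      # stage 3 transition pool:     8 x 70
h, w = h // 2, w // 2                      # stage 5 transition pool:     4 x 35
print(h, w)  # 4 35 -> after Permute, 35 time steps of 4 * 192 = 768 features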

Summary

  • The takeaway from the above is DenseNet's concatenation trick and how it keeps the channel count under control; a toy sketch follows.
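
A toy sketch of that channel bookkeeping, mirroring the arithmetic in the dense_cnn comments (illustrative only, not part of the original code):

# Each dense block's 8 conv_blocks add growth_rate channels by concatenation;
# each transition block's 1x1 conv then resets the count to 128.
nb_filter, growth_rate = 64, 8
for block in range(3):
    nb_filter += 8 * growth_rate           # dense block: +64 channels
    print('after dense block %d: %d channels' % (block + 1, nb_filter))
    if block < 2:
        nb_filter = 128                    # transition block resets to 128
# prints 128, 192, 192 -- matching the 64 -> 128 -> 192 -> 128 -> 192 comments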