tensorflow2制作Resnet残差网络

使用tensorflow2复现 Resnet v1&v2

关于resnet的原理、解析参考论文以及网络博客, 这里就不再复述了, 我主要是看其他框架版本的resnet进行理解, 毕竟文字描述不清楚, 细节都在代码中体现

导入相应的包

这里使用的tensorflow2

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, Dense
from tensorflow.keras.layers import Input, Flatten, AveragePooling2D

构建网络

先看看两个版本的区别

如下图片对resnet v1还是v2的区别描述的很清晰, 根据此图, 创建网络模块, 当然把两个版本都写出来.
resnet结构图片
可以看出,V1 V2 的区别就在于 (V1)Conv->BN->Activation 还是 (V2)BN->Activation->Conv
----------------------------------------------------------分割线----------------------------------------------------------

实现基础模块

如下图, 我个人认为,图中的 weight layer 的具体内容就是上述的 (V1)Conv->BN->Activation 或者 (V2)BN->Activation->Conv (注意!一个weight layer中不一定都存在这三者), 因此让我们实现这一小块的内容。Resnet的Shortcut

def weight_layer(inputs,
                 filters=16,
                 kernel_size=3,
                 strides=1,
                 activation='relu', 
                 batch_normalization=True, 
                 conv_first=True):
    '''
    这里仅仅只是实现残差中的 (V1)Conv->BN->Activation 或者 (V2)BN->Activation->Conv
    args:
        inputs(tensor):     最初输入的图片或者上一层输入的图片
        filters(int):       卷积核的数量
        kernel_size(int):   卷积核的大小
        strides(int):       卷积核移动的距离
        activation(str):    激活函数
        batch_normalization(bool): 是否使用batch_normalization
        conv_first(bool):   Resnet使用版本, resnetv1(True), resnetv2(False)
    '''
    conv = Conv2D(filters, 
                  kernel_size=kernel_size,
                  strides=strides, 
                  padding='same', 
                  kernel_initializer='he_normal',   
                  kernel_regularizer=keras.regularizers.l2(1e-4))
    
    x = inputs
    if conv_first:
        # Resnet v1
        x = conv(x)
        if batch_normalization:     # 可以随意选择是否添加BN层
            x = BatchNormalization()(x)
        if activation is not None:  # 可以随意选择是否添加激活函数层
            x = Activation(activation)(x)
    else:
        # Resnet v2
        if batch_normalization:
            x = BatchNormalization()(x)
        if activation is not None:
            x = Activation(activation)(x)
        x = conv(x)
    return x

接下来实现 Resnet V1 的网络结构,每个小块的结构都与第一张图里的一致,为了便于看代码我把网络结构图(depth=8)贴在下面:
Resnet V1

def resnet_v1(input_shape, depth, num_classes=10):
    '''
    实现Resnet V1网络

    args:
        input_shape(tube): 输入图像的shape,例如(128, 128, 3)
        depth(int):        网络的深度
        num_classes(int):  分类器输出结果的种类 
    
    return:
        model:            Resnet V1网络模型
    '''
    # depth必须是6n+2 
    # 2 是指刚开始的weight layers块以及最后的分类器
    # 6n 是因为: stack一共三个,每个stack中有两个res_blocks (具体看下面的循环)
    if (depth - 2) % 6 != 0:
        raise ValueError('网络深度必须是 6n+2!')
    num_filters = 16
    num_res_blocks = int((depth - 2) / 6)

    inputs = Input(shape=input_shape)
    x = weight_layer(inputs)

    for stack in range(3):
        for res_block in range(num_res_blocks):
            strides = 1
            if stack > 0 and res_block == 0:
                strides = 2

            y = weight_layer(x, filters=num_filters, strides=strides)
            # 这里只添加了一个卷积层以及一个BN层
            y = weight_layer(y, filters=num_filters, activation=None)
            # 纬度不对应时(因strides=2引起的),使用size=1的卷积核进行调整,让x的纬度变大以便与y相加(通过num_filters让纬度相同)
            # 这里的1*1的卷积更多的起得是整合纬度的作用
            if stack > 0 and res_block == 0:
                x = weight_layer(x, 
                filters=num_filters, 
                kernel_size=1, 
                strides=strides, 
                activation=None, 
                batch_normalization=False)

            x = keras.layers.add([x, y])
            x = Activation('relu')(x)
        num_filters *= 2
    
    # 添加分类器
    x = AveragePooling2D(pool_size=8)(x)
    y = Flatten()(x)
    outputs = Dense(num_classes, activation='softmax', kernel_initializer='he_normal')(y)

    model = keras.Model(inputs=inputs, outputs=outputs)
    return model
未完待续
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值