Resnet实现,用通用函数实现全部的resnet类型,如resnet18,resnet34,resnet50等等

一、准备

网上对于何凯明等大佬写的Residual论文的解读已经够多了,经过一段时间的学习,我想摸索出一套适合所有resnet类型的通用函数,以便实验,故而在这篇博客中,我重点讲如何实现通用函数。

重点1:

 在上面图中,我们需要注意 F(x) + x 是在 激活函数relu之前进行的,知道这一点是为了实现卷积函数conv2D_BN时,先不进行激活。

重点2:

我们通过观察可知 只有在每种残差快的第一个块,其shortcut连接才需要 1x1 卷积,其相应的连接也是虚线,比如3处,但是有一处需要注意,即 1 处,它也是一种残差块的第一块,但是不需要1x1 卷积。其连接是实线,而类似于 2处 的 一种残差块内部的连接,也是不需要1x1 卷积,其连接也是实线。

重点3:

 为了便于建立通用的函数,适合所有的残差类型,我们需要创建一种表,来对应上述的内容。如下所示:

# 50-layer
#比如[3, [[64,(1,1)],[64,(3,3)],[256,(1,1)]], 第一参数 3 是 一种残差块的个数,
#第二个参数是此种残差块的卷积层的一些参数,其是过滤器个数filtes,kernel_size
filter_list_resnet50 = [ [3, [[64,(1,1)],[64,(3,3)],[256,(1,1)]] ],
                [4, [[128,(1,1)],[128,(3,3)],[512,(1,1)]] ],
                [6, [[256,(1,1)],[256,(3,3)],[1024,(1,1)]] ],
                [3, [[512,(1,1)],[512,(3,3)],[2048,(1,1)]] ]]
# 18-layer
filter_list_resnet18 = [ [2, [[64,(3,3)],[64,(3,3)]] ],
                        [2, [[128,(3,3)],[128,(3,3)]] ],
                       [2, [[256,(3,3)],[256,(3,3)]] ],
                       [2, [[512,(3,3)],[512,(3,3)]] ]]

 重点4:

 有时候特征图大小不变,有时候减半,其对应的padding就有可能不同。

除了1x1卷积用padding=‘valid’之外,其他的都用padding=‘same’。

二、实现

第一步,导入库

import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt

 第二步,实现卷积归一化

def conv2D_BN(x, num_filter, kernel_size, strides=(1,1), padding='same'):
    '''
    为了方便 F(x) + x ,之后再relu激活
    故此卷积没有激活函数
    '''
    conv = keras.layers.Conv2D(filters=num_filter, kernel_size=kernel_size,
                       strides=strides, padding=padding, 
                        kernel_regularizer=keras.regularizers.l2(0.0001))(x)
    bn = keras.layers.BatchNormalization()(conv)
    return bn

第三步,实现基本残差块

从论文中,可以知晓,若特征图大小output map size不变,那么过滤器数目不变;若大小减半,则过滤器数目加倍。前者对应building_block内部,后者对应building_block之间。

由此,我们可以得到padding='same'。步长strides=1,则特征图大小不变,步长为2,则大小减半。

def building_block(x, filters_list, is_first_layers=False):
    '''
    这是一个基本残差块,适用于任何残差块类型。
    is_first_layers=True,说明此时步长strides=2,特征图大小需要减半,
    否则步长为1,特征图大小不变;
    同时也说明是shortcut是需要 1x1卷积的,即shortcut虚线部分;否则无需改变。
    filter_list: 包含若干个列表,每个列表包含一种类型的残差块,其信息如下:
                    此类残差块个数,[过滤器数目,核大小],[过滤器数目,核大小],,,
    '''
    y = x
    strides=(1,1)
    for i in range(len(filters_list)):
        if is_first_layers and i == 0:
            strides=(2,2)
        else:
            strides=(1,1)
        y = conv2D_BN(y, filters_list[i][0],kernel_size=filters_list[i][1],strides=strides)
        # short_cut
        '''
        is_first_layers为True,并且为残差块的最后一层
        此时说明需要1x1卷积,改变x即input的特征图大小,即减半,步长为2。其过滤器数目
        filters需要同最后一层即当前层的过滤器数目相同,即filters=filters_list[i][0]
        '''
        if is_first_layers and i == len(filters_list) - 1:
            x =  conv2D_BN(x, filters_list[i][0],kernel_size=(1,1),
                           strides=(2,2), padding='valid')
            break
        #若是残差块的最后一层,则先不需要激活,先进行相加操作,即残差块的输入和输出相加
        #其他情况,即残差块的内部层之间,可以直接激活
        if i == len(filters_list) - 1:
            break
        y = keras.layers.Activation('relu')(y)
    f = keras.layers.add([x, y])
    return keras.layers.Activation('relu')(y)

第四步,实现残差网络主体区域,即不同的地方

def residual_main_network(x, filter_list_resnet):
    for i in range(len(filter_list_resnet)):
        for j in range(filter_list_resnet[i][0]):
            #倘若是一种类型残差块的第一个块,即j==0,且不能是第一种残差块,因为第一种残差块
            #不需要shortcut,即 i != 0
            if j == 0 and i != 0:
                is_first_layers=True
            else:
                is_first_layers=False
            x = building_block(x, filters_list=filter_list_resnet[i][1],
                           is_first_layers=is_first_layers)
    return x

第五步,实现残差网络

def resnet(nclass,input_shape, filter_list_resnet): #nclass是输出种类数,input_shape是输入形状
    input_ = keras.layers.Input(shape=input_shape)
    conv1 = conv2D_BN(input_, 64, kernel_size=(7,7), strides=(2,2))
    conv1 = keras.layers.Activation('relu')(conv1)
    pool1 = keras.layers.MaxPool2D(pool_size=(3, 3),strides=(2, 2),padding='same')(conv1)
    
    conv2 = residual_main_network(pool1, filter_list_resnet)
    
    pool2 = keras.layers.GlobalAvgPool2D()(conv2)
    output_ = keras.layers.Dense(nclass, 'softmax')(pool2)
    
    model = keras.Model(inputs=input_,outputs=output_)
    return model

三、举例:用上述模型实现resnet18

filter_list_resnet18 = [ [2, [[64,(3,3)],[64,(3,3)]] ],
                        [2, [[128,(3,3)],[128,(3,3)]] ],
                       [2, [[256,(3,3)],[256,(3,3)]] ],
                       [2, [[512,(3,3)],[512,(3,3)]] ]]
model = resnet(10, (32,32,3), filter_list_resnet18)
model.summary()

运行结果:

Model: "model_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_5 (InputLayer)         [(None, 32, 32, 3)]       0         
_________________________________________________________________
conv2d_66 (Conv2D)           (None, 16, 16, 64)        9472      
_________________________________________________________________
batch_normalization_66 (Batc (None, 16, 16, 64)        256       
_________________________________________________________________
activation_55 (Activation)   (None, 16, 16, 64)        0         
_________________________________________________________________
max_pooling2d_4 (MaxPooling2 (None, 8, 8, 64)          0         
_________________________________________________________________
conv2d_67 (Conv2D)           (None, 8, 8, 64)          36928     
_________________________________________________________________
batch_normalization_67 (Batc (None, 8, 8, 64)          256       
_________________________________________________________________
activation_56 (Activation)   (None, 8, 8, 64)          0         
_________________________________________________________________
conv2d_68 (Conv2D)           (None, 8, 8, 64)          36928     
_________________________________________________________________
batch_normalization_68 (Batc (None, 8, 8, 64)          256       
_________________________________________________________________
activation_57 (Activation)   (None, 8, 8, 64)          0         
_________________________________________________________________
conv2d_69 (Conv2D)           (None, 8, 8, 64)          36928     
_________________________________________________________________
batch_normalization_69 (Batc (None, 8, 8, 64)          256       
_________________________________________________________________
activation_58 (Activation)   (None, 8, 8, 64)          0         
_________________________________________________________________
conv2d_70 (Conv2D)           (None, 8, 8, 64)          36928     
_________________________________________________________________
batch_normalization_70 (Batc (None, 8, 8, 64)          256       
_________________________________________________________________
activation_59 (Activation)   (None, 8, 8, 64)          0         
_________________________________________________________________
conv2d_71 (Conv2D)           (None, 4, 4, 128)         73856     
_________________________________________________________________
batch_normalization_71 (Batc (None, 4, 4, 128)         512       
_________________________________________________________________
activation_60 (Activation)   (None, 4, 4, 128)         0         
_________________________________________________________________
conv2d_72 (Conv2D)           (None, 4, 4, 128)         147584    
_________________________________________________________________
batch_normalization_72 (Batc (None, 4, 4, 128)         512       
_________________________________________________________________
activation_61 (Activation)   (None, 4, 4, 128)         0         
_________________________________________________________________
conv2d_74 (Conv2D)           (None, 4, 4, 128)         147584    
_________________________________________________________________
batch_normalization_74 (Batc (None, 4, 4, 128)         512       
_________________________________________________________________
activation_62 (Activation)   (None, 4, 4, 128)         0         
_________________________________________________________________
conv2d_75 (Conv2D)           (None, 4, 4, 128)         147584    
_________________________________________________________________
batch_normalization_75 (Batc (None, 4, 4, 128)         512       
_________________________________________________________________
activation_63 (Activation)   (None, 4, 4, 128)         0         
_________________________________________________________________
conv2d_76 (Conv2D)           (None, 2, 2, 256)         295168    
_________________________________________________________________
batch_normalization_76 (Batc (None, 2, 2, 256)         1024      
_________________________________________________________________
activation_64 (Activation)   (None, 2, 2, 256)         0         
_________________________________________________________________
conv2d_77 (Conv2D)           (None, 2, 2, 256)         590080    
_________________________________________________________________
batch_normalization_77 (Batc (None, 2, 2, 256)         1024      
_________________________________________________________________
activation_65 (Activation)   (None, 2, 2, 256)         0         
_________________________________________________________________
conv2d_79 (Conv2D)           (None, 2, 2, 256)         590080    
_________________________________________________________________
batch_normalization_79 (Batc (None, 2, 2, 256)         1024      
_________________________________________________________________
activation_66 (Activation)   (None, 2, 2, 256)         0         
_________________________________________________________________
conv2d_80 (Conv2D)           (None, 2, 2, 256)         590080    
_________________________________________________________________
batch_normalization_80 (Batc (None, 2, 2, 256)         1024      
_________________________________________________________________
activation_67 (Activation)   (None, 2, 2, 256)         0         
_________________________________________________________________
conv2d_81 (Conv2D)           (None, 1, 1, 512)         1180160   
_________________________________________________________________
batch_normalization_81 (Batc (None, 1, 1, 512)         2048      
_________________________________________________________________
activation_68 (Activation)   (None, 1, 1, 512)         0         
_________________________________________________________________
conv2d_82 (Conv2D)           (None, 1, 1, 512)         2359808   
_________________________________________________________________
batch_normalization_82 (Batc (None, 1, 1, 512)         2048      
_________________________________________________________________
activation_69 (Activation)   (None, 1, 1, 512)         0         
_________________________________________________________________
conv2d_84 (Conv2D)           (None, 1, 1, 512)         2359808   
_________________________________________________________________
batch_normalization_84 (Batc (None, 1, 1, 512)         2048      
_________________________________________________________________
activation_70 (Activation)   (None, 1, 1, 512)         0         
_________________________________________________________________
conv2d_85 (Conv2D)           (None, 1, 1, 512)         2359808   
_________________________________________________________________
batch_normalization_85 (Batc (None, 1, 1, 512)         2048      
_________________________________________________________________
activation_71 (Activation)   (None, 1, 1, 512)         0         
_________________________________________________________________
global_average_pooling2d_3 ( (None, 512)               0         
_________________________________________________________________
dense_3 (Dense)              (None, 10)                5130      
=================================================================
Total params: 11,019,530
Trainable params: 11,011,722
Non-trainable params: 7,808

下面用这个模型去对cifar10进行训练

第一步:编译模型参数和导入数据集并预处理

model.compile(optimizer=tf.optimizers.Adam(0.001),
             loss=tf.losses.SparseCategoricalCrossentropy(),
             metrics=['acc'])

from keras.datasets import cifar10
(x_train, y_train), (x_val, y_val) = cifar10.load_data()
x_train = x_train / 255
x_val = x_val / 255

第二步:观察数据集

print(x_train.shape)

plt.figure()
plt.imshow(x_train[0])
plt.show()

第三步:拟合数据集,训练网络

model.fit(x_train,y_train,validation_data=(x_val,y_val),epochs=10,
          batch_size=64)

运行结果如下:

Epoch 1/10
782/782 [==============================] - 1038s 1s/step - loss: 1.7900 - acc: 0.4212 - val_loss: 2.2747 - val_acc: 0.3473
Epoch 2/10
782/782 [==============================] - 1084s 1s/step - loss: 1.4167 - acc: 0.5629 - val_loss: 1.6816 - val_acc: 0.4755
Epoch 3/10
782/782 [==============================] - 1047s 1s/step - loss: 1.2337 - acc: 0.6355 - val_loss: 1.9268 - val_acc: 0.4499
Epoch 4/10
782/782 [==============================] - 1059s 1s/step - loss: 1.1222 - acc: 0.6760 - val_loss: 1.4456 - val_acc: 0.5592
Epoch 5/10
782/782 [==============================] - 1075s 1s/step - loss: 1.0435 - acc: 0.7047 - val_loss: 1.7463 - val_acc: 0.5160
Epoch 6/10
782/782 [==============================] - 1094s 1s/step - loss: 0.9957 - acc: 0.7297 - val_loss: 1.9739 - val_acc: 0.5149
Epoch 7/10
782/782 [==============================] - 1109s 1s/step - loss: 0.9553 - acc: 0.7510 - val_loss: 1.3359 - val_acc: 0.6366
Epoch 8/10
782/782 [==============================] - 1120s 1s/step - loss: 0.9221 - acc: 0.7681 - val_loss: 1.3839 - val_acc: 0.6401
Epoch 9/10
782/782 [==============================] - 1129s 1s/step - loss: 0.8882 - acc: 0.7852 - val_loss: 1.1889 - val_acc: 0.6920
Epoch 10/10
782/782 [==============================] - 1137s 1s/step - loss: 0.8584 - acc: 0.8003 - val_loss: 1.3718 - val_acc: 0.6465

由于电脑不咋地,所以一些参数没有优化,你如正则化,epochs大小,batch_size等等。

如有错误,欢迎指正‘‘‘‘‘’’’’’

  • 1
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值