实现ResNet使用的也是先搭建网络,然后评价迭代耗时的方法,在这里我们把 代码分成了ResNet_struct.py和ResNet_run.py,其中ResNet_struct.py用来搭建网络,ResNet_run.py用来评测网络,今天主要介绍ResNet_struct.py,我们先从搭建网络开始,导入一些需要用到的库:
接下来创建一个Block类,这个类用来配置残差学习模型的大小。初始化一个Block类需要传入3个参数:name、residual_unit和args。其中name就是这个残差学习模块的名称。residual_unit指的就是创建这个残差学习模块用到的函数,args指的就是残差学习模块的大小信息。
接下来定义一个conv2d_same()函数,创建卷积层:
接下来通过函数 residual_unit()定义残差学习单元的创建过程。我们来看一下 residual_unit()函数的参数, inputs 是输入,它是模拟的图像数据集经由 ResNet 最开始的卷积层和池化层处理得到。residual_unit()被创建