基于imagenet数据集的ResNet50模型训练示例
训练前准备
数据集获取
本训练示例以imagenet数据集为例,从imagenet官方网站http://www.image-net.org/获取数据集。
模型功能描述
Resnet50为一个深度残差网络,可用于对CIFAR-10和ImageNet的1000类数据集进行分类。
目录结构
主要文件目录结构如下所示(只列出部分涉及文件,更多文件请查看获取的ResNet原始网络脚本):├── r1 // 原始模型目录
│ ├── resnet // resnet主目录
│ ├── __init__.py
│ ├── imagenet_main.py // 基于Imagenet数据集训练网络模型
│ ├── imagenet_preprocessing.py // Imagenet数据集数据预处理模块
│ ├── resnet_model.py // resnet模型文件
│ ├── resnet_run_loop.py // 数据输入处理与运行循环(训练、验证、测试)
│ ├── README.md // 项目介绍文件
│ ├── utils
│ │ ├── export.py // 数据接收函数,定义了导出的模型将会对何种格式的参数予以响应
├── utils
│ ├── flags
│ │ ├── core.py // 包含了参数定义的公共接口
│ ├── logs
│ │ ├── hooks_helper.py //自定义创建模型在测试/训练时的工具,比如每秒钟计算步数的功能、每N步或捕获CPU/GPU分析信息的功能等
│ │ ├── logger.py // 日志工具
│ ├── misc
│ │ ├── distribution_utils.py // 进行分布式运行模型的辅助函数
│ │ ├── model_helpers.py // 定义了一些能被模型调用的函数,比如控制模型是否停止
迁移说明
以下迁移过程,我们不借助迁移工具进行迁移,按照完全手工修改脚本的方式进行迁移,让大家更详细地了解当前脚本的所有迁移要点。
训练流程
Estimator简介
Estimator API属于TensorFlow的高阶API,在2018年发布的TensorFlow 1.10版本中引入,它可极大简化机器学习的编程过程。Estimator有很多优势,例如:对分布式的良好支持、简化了模型的创建工作、有利于模型开发者之间的代码分享等。
使用Estimator进行训练脚本开发的流程为:
表10-1 训练流程说明过程
描述
数据预处理
创建输入函数input_fn。
模型构建
构建模型函数model_fn。
运行配置
实例化Estimator,并传入Runconfig类对象作为运行参数。
执行训练
在Estimator上调用训练方法Estimator.train(),利用指定输入对模型进行固定步数的训练。
训练代码目录
目录结构
主要文件目录结构如下所示(只列出部分需要修改文件,更多文件请查看获取的ResNet原始网络脚本):├── r1
│ ├── resnet // resnet主目录
│ ├── imagenet_main.py // 基于Imagenet数据集训练网络模型
│ ├── imagenet_preprocessing.py // Imagenet数据集数据预处理模块
│ ├── resnet_model.py // resnet模型文件
│ ├── resnet_run_loop.py // 数据输入处理与运行循环(训练、验证、测试)
├── utils
│ ├── flags
│ │ ├── _base.py //定义模型的通用参数并设置默认值
目录文件简介
表10-2 py文件作用及功能文件名称
简介
imagenet_main.py
包含imagenet数据集数据预处理、模型构建定义、模型运行的相关函数接口。其中数据预处理部分包含get_filenames()、parse_record()、input_fn()、get_synth_input_fn(),_parse_example_proto()函数,模型部分包含ImagenetModel类、imagenet_model_fn()、run_cifar()、define_cifar_flags()函数。
imagenet_preprocessing.py
imagenet图像数据预处理接口,训练过程中包括使用提供的边界框对训练图像进行采样、将图像裁剪到采样边界框、随机翻转图像,然后调整到目标输出大小(不保留纵横比)。评估过程中使用图像大小调整(保留纵横比)和中央剪裁。
resnet_model.py
ResNet模型的实现,包括辅助构建ResNet模型的函数以及ResNet block定义函数。
resnet_run_loop.py
模型运行文件,包括输入处理和运行循环两部分,输入处理包括对输入数据进行解码和格式转换,输出image和label,还根据是否是训练过程对数据的随机化、批次、预读取等细节做出了设定;运行循环部分包括构建Estimator,然后进行训练和验证过程。总体来看,是将模型放置在具体的环境中,实现数据与误差在模型中的流动,进而利用梯度下降法更新模型参数。
数据预处理
数据预处理流程与原始模型一致,部分位置经改造以适配昇腾910 AI处理器并提升计算性能,展示的示例代码包含改动位置。
定义输入函数input_fn
这里以imagenet数据集数据预处理为例,其数据预处理部分涉及到的适配昇腾910AI处理器改造的py文件及相关函数接口介绍如下:
表10-3 数据预处理API接口
简介
位置
input_fn()
输入函数,用于处理数据集用于Estimator训练,输出真实数据。
“/official/r1/resnet/imagenet_main.py”
resnet_main()
包含数据输入、运行配置、训练以及验证的主接口。
“/official/r1/resnet/resnet_run_loop.py”
在“official/r1/resnet/imagenet_main.py”文件中增加以下头文件:from hccl.manage.api import get_rank_size
from hccl.manage.api import get_rank_id
在数据读取时,获取芯片数量以及芯片id,用于支持数据并行。
代码位置:“official/r1/resnet/imagenet_main.py”的input_fn()函数(修改部分为字体加粗部分):
definput_fn(is_training, data_dir, batch_size, num_epochs=1, dtype=tf.float32,
datasets_num_private_threads=None, parse_record_fn=parse_record,
input_context=None, drop_remainder=False, tf_data_experimental_slack=False):
"""提供训练和验证batches的函数。
参数解释:
is_training: 表示输入是否用于训练的布尔值。
data_dir: 包含输入数据集的文件路径。
batch_size: 每个batch的大小。
num_epochs: 数据集的重复数。
dtype: 图片/特征的数据类型。
datasets_num_private_threads: tf.data的专用线程数。
parse_record_fn: 解析tfrecords的入口函数。
input_context: 由'tf.distribute.Strategy'传入的'tf.distribute.InputContext'对象。
drop_remainder: 用于标示对于最后一个batch如果数据量达不到batch_size时保留还是抛弃。设置为True,则batch的维度固定。
tf_data_experimental_slack: 是否启用tf.data的'experimental_slack'选项。
Returns:
返回一个可用于迭代的数据集。
"""
# 获取文件路径
filenames = get_fil