本文主要介绍谷歌官方在Github TensorFlow中开源的官方代码DeepLab在读取TFRecord格式数据集所使用的方法。
配置DeepLab v3
首先,需要将整个工程拉取到本地的workspace。
1. 源码地址:https://github.com/tensorflow/models/tree/master/research/deeplab
2. 将源代码拉取到自己的workspace中。
git clone https://github.com/tensorflow/models.git
3. 测试是否安装配置成功。
# From tensorflow/models/research/
export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim
# From tensorflow/models/research/
python deeplab/model_test.py
读取数据集代码分析
读取数据集部分的代码出现在以下文件中,以PASCAL_VOC的TFRecord格式数据集进行训练过程为例:
(1)deeplab/train.py
(2)deeplab/datasets/segmentation_dataset.py
(3)deeplab/utils/input_generator.py
1. 输入指令参数
在train.py中可看到以下代码,即需要输入3个参数:train_logdir、tf_initial_checkpoint、dataset_dir。
if __name__ == '__main__':
flags.mark_flag_as_required('train_logdir')
flags.mark_flag_as_required('tf_initial_checkpoint')
flags.mark_flag_as_required('dataset_dir')
tf.app.run()
其中
train_logdir="/deeplab/datasets/pascal_voc_seg/exp/train_on_train_set/train"(训练结束后的checkpoint存放路径)
tf_initial_checkpoint="/deeplab/datasets/cityscapes/deeplabv3_mnv2_pascal_trainval/ model.ckpt-30000.index"(预训练好的checkpoint路径)
dataset_dir="/deeplab/datasets/pascal_voc_seg/tfrecord"(数据集路径)
2. 通过指令输入的参数,获得一个slim.Dataset的实例
2.1 调用segmentation_dataset.py中的get_dataset()函数。
dataset = segmentation_dataset.get_dataset(
FLAGS.dataset, FLAGS.train_split, dataset_dir=FLAGS.dataset_dir)
输入参数如下:
FLAGS.dataset= 'pascal_voc_seg'
FLAGS.train_split= 'train'
FLAGS.dataset_dir='/deeplab/datasets/pascal_voc_seg/tfrecord'(即在1中输入的dataset_dir参数)
2.2 在segmentation_dataset.py中的get_dataset()函数,定义如下:
def get_dataset(dataset_name, split_name, dataset_dir):
(1)首先,进行两个判断。输入的参数中,dataset_name必须是pascal_voc_seg、cityscapes、ade20k其中的一个,否则报错;接着获取数据集的基本信息,如果输入的split_name不是train、train_aug、trainval、val其中的一个,则报错。
if dataset_name not in _DATASETS_INFORMATION:
raise ValueError('The specified dataset is not supported yet.')
splits_to_sizes = _DATASETS_INFORMATION[dataset_name].splits_to_sizes