文章目录
TF object_detection API
这个API是tensorflow官方提供的工程模板,之前曾经尝试过但没有跑通,这次看的比较深入,基本上熟悉了训练、测试、评估的操作流程。实验了VOC2007训练、Pet数据集训练等。下面记录的是研究过程中的一些总结。
使用API训练数据集的一般流程
适当修改下面的对应路径和配置文件
1. 创建tfrecord
python dataset_tools/create_pet_tf_record.py \
--data_dir=/media/han/E/mWork/datasets/Oxford-IIIT_Pet_Dataset \
--output_dir=trainLogs_pets/tfrecord
2. 训练
如果不能运行,那么执行:
# From tensorflow/models/research/
export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim
python train.py \
--logtostderr \
--train_dir=trainLogs_pets/output \
--pipeline_config_path=trainLogs_pets/ssd_mobilenet_v1_pets.config
3. 将训练得到的权重文件合并为*.pb文件
python export_inference_graph.py --input_type image_tensor \
--pipeline_config_path trainLogs_pets/ssd_mobilenet_v1_pets.config \
--trained_checkpoint_prefix trainLogs_pets/output/model.ckpt-100000 \
--output_directory trainLogs_pets/output
4. 评估
python eval.py \
--logtostderr \
--checkpoint_dir=trainLogs_pets/output \
--eval_dir=trainLogs_pets/eval \
--pipeline_config_path=trainLogs_pets/ssd_mobilenet_v1_pets.config
create_pascal_tf_record.py
ignore_difficult_instances #忽视难例就是不训练难例
- 该文件会将图像文件也编码进.record文件中,所以生成文件比较大
with tf.gfile.GFile(full_path, 'rb') as fid:
encoded_jpg = fid.read()
example = tf.train.Example(features=tf.train.Features(feature={
'image/height': dataset_util.int64_feature(height),
'image/width': dataset_util.int64_feature(width),
'image/filename': dataset_util.bytes_feature(
data['filename'].encode('utf8')),
'image/source_id': dataset_util.bytes_feature(
data['filename'].encode('utf8')),
'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
'image/encoded': dataset_util.bytes_feature(encoded_jpg), #图像raw数据
'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')), #jpeg格式,也就是说保存的图像是压缩后的大小
'image/object/bbox/xmin': dataset_util.float_list_feature(xmin),
'image/object/bbox/xmax': dataset_util.float_list_featur