https://github.com/rishizek/tensorflow-deeplab-v3
一、基本流程
1. Setup
2. Train
- 将原始数据转换为TFRecord格式(可以提高训练速度)
python create_pascal_tf_record.py --data_dir DATA_DIR \
--image_data_dir IMAGE_DATA_DIR \
--label_data_dir LABEL_DATA_DIR
- 开始训练:
python train.py --model_dir MODEL_DIR
--pre_trained_model PRE_TRAINED_MODEL
--pre_trained_model
包含了预训练的 Resnet 模型,model_dir
包含了该程序训练的 DeepLab-V3 checkpoints。
3. Inference
对模型进行测试:
python inference.py --data_dir DATA_DIR
--infer_data_list INFER_DATA_LIST
--model_dir MODEL_DIR
二、TFRecord 转换代码
create_pascal_tf_record.py
1. 图片转为 tf.Example
def dict_to_tf_example(image_path, label_path):
2. 用 sample 建立 TFRecord
def create_tf_record(output_filename,
image_dir,
label_dir,
examples):