参考:
- Fine-tuning a model from an existing checkpoint
- TF-Slim 实现模型迁移/微调
- Tensorflow-pb保存与导入
- 迁移学习/fine-tuning
- Tensorflow-变量保存与导入
完整程序:点击这里
数据下载并转成tfrecord格式
$ DATA_DIR=/tmp/data/flowers
$ CUDA_VISIBLE_DEVICES=1 python3 download_and_convert_data.py \
--dataset_name=flowers \
--dataset_dir="${DATA_DIR}"
# flower_photos文件结构
<flower_photos>
|--- daisy
| |--- *.jpg
|--- dandelion
| |--- *.jpg
|--- roses
| |--- *.jpg
|--- sunflowers
| |--- *.jpg
|--- tulips
|___ |--- *.jpg
# 共5个类别
# 数据转成tfrecord格式
$ ls ${DATA_DIR}
flowers_train-00000-of-00005.tfrecord
...
flowers_train-00004-of-00005.tfrecord
flowers_validation-00000-of-00005.tfrecord
...
flowers_validation-00004-of-00005.tfrecord
labels.txt
下载pre-train model
$ CHECKPOINT_DIR=/tmp/checkpoints
$ mkdir ${CHECKPOINT_DIR}
$ wget http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz
$ tar -xvf inception_v3_2016_08_28.tar.gz
$ mv inception_v3.ckpt ${CHECKPOINT_DIR}
$ rm inception_v3_2016_08_28.tar.gz
执行迁移/微调
$ DATASET_DIR=/tmp/data/flowers # tfrecord数据路径
$ TRAIN_DIR=/tmp/flowers-models/inception_v3 # 新网络模型保存位置
$ CHECKPOINT_PATH=/tmp/checkpoints/inception_v3.ckpt # pre-train model 路径
$ CUDA_VISIBLE_DEVICES=1 python3 train_image_classifier.py \
--train_dir=${TRAIN_DIR} \
--dataset_dir=${DATASET_DIR} \
--dataset_name=flowers \
--dataset_split_name=train \
--model_name=inception_v3 \
--checkpoint_path=${CHECKPOINT_PATH} \
--checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
--trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits
# ---------说明----------------
--checkpoint_exclude_scopes # 第一次不加载这些参数
--trainable_scopes # 重新训练这部分参数
评估模型性能
要评估模型的性能(无论是预训练还是自己的),可以使用eval_image_classifier.py
脚本,如下所示。
# 评估上面的迁移学习的模型
CHECKPOINT_FILE =/tmp/flowers-models/inception_v3 # Example
DATASET_DIR=/tmp/data/flowers
$ CUDA_VISIBLE_DEVICES=1 python3 eval_image_classifier.py \
--alsologtostderr \
--checkpoint_path=${CHECKPOINT_FILE} \
--dataset_dir=${DATASET_DIR} \
--dataset_name=flowers \
--dataset_split_name=validation \
--model_name=inception_v3
导出推理图
# 保存图表并保存变量参数
from tensorflow.python.framework import graph_util
var_list=tf.global_variables()
constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def,output_node_names=[var_list[i].name for i in range(len(var_list))]) # 保存图表并保存变量参数
tf.train.write_graph(constant_graph, './output', 'expert-graph.pb', as_text=False)
# -----方式2-------------------
from tensorflow.python.framework import graph_util
var_list=tf.global_variables()
constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def,output_node_names=[var_list[i].name for i in range(len(var_list))])
with tf.gfile.FastGFile(logdir+'expert-graph.pb', mode='wb') as f:
f.write(constant_graph.SerializeToString())
# 只保留图表
graph_def = tf.get_default_graph().as_graph_def()
with gfile.GFile('./output/export.pb', 'wb') as f:
f.write(graph_def.SerializeToString())
# 或者
tf.train.write_graph(graph_def, './output', 'expert-graph.pb', as_text=False)
$ CUDA_VISIBLE_DEVICES=1 python3 export_inference_graph.py \
--alsologtostderr \
--model_name=inception_v3 \
--output_file=/tmp/inception_v3_inf_graph.pb
自己的数据实现迁移/微调
1、修改下download_and_convert_data.py
2、根据download_and_convert_flowers.py
重新一个自己数据的脚本download_and_convert_my_data.py
3、 根据flowers.py
改写成自己数据的脚本my_data.py
4、修改dataset_factory.py
5、根据需要修改模型文件inception_v3.py
,如:增加或者减少层
数据转换
$ DATA_DIR=./train # 自己数据的目录(事先下载好)
$ CUDA_VISIBLE_DEVICES=1 python3 download_and_convert_data.py \
--dataset_name=my_data \
--dataset_dir="${DATA_DIR}"
# train 目录结构
<train>
|--- 0001
| |--- *.jpg
|--- 0002
| |--- *.jpg
|--- 0003
| |--- *.jpg
|....
下载pre-train model
与上面相同
$ CHECKPOINT_DIR=/tmp/checkpoints
$ mkdir ${CHECKPOINT_DIR}
$ wget http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz
$ tar -xvf inception_v3_2016_08_28.tar.gz
$ mv inception_v3.ckpt ${CHECKPOINT_DIR}
$ rm inception_v3_2016_08_28.tar.gz
执行迁移/微调
$ DATASET_DIR=./train # tfrecord数据路径
$ TRAIN_DIR=./inception_v3 # 新网络模型保存位置
$ CHECKPOINT_PATH=/tmp/checkpoints/inception_v3.ckpt # pre-train model 路径
$ CUDA_VISIBLE_DEVICES=1 python3 train_image_classifier.py \
--train_dir=${TRAIN_DIR} \
--dataset_dir=${DATASET_DIR} \
--dataset_name=my_data \
--dataset_split_name=train \
--model_name=inception_v3 \
--checkpoint_path=${CHECKPOINT_PATH} \
--checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
--trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits
评估模型性能
# 评估上面的迁移学习的模型
CHECKPOINT_FILE =./inception_v3 # Example
DATASET_DIR=./train
$ CUDA_VISIBLE_DEVICES=1 python3 eval_image_classifier.py \
--alsologtostderr \
--checkpoint_path=${CHECKPOINT_FILE} \
--dataset_dir=${DATASET_DIR} \
--dataset_name=my_data \
--dataset_split_name=validation \
--model_name=inception_v3
注:注意路径对应,不要弄混了