1. Install TensorFlow and TensorFlow Hub
pip install "tensorflow>=1.7.0"
pip install tensorflow-hub
2. Download the flower photo dataset
curl -LO http://download.tensorflow.org/example_images/flower_photos.tgz
tar xzf flower_photos.tgz
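retrain.py treats each subdirectory of --image_dir as one class and uses the directory name as the label, so after extracting the archive it is worth checking that flower_photos holds one folder per class with images in it. A minimal sketch of such a check, assuming bash or another POSIX-like shell; count_images is a hypothetical helper name, not part of the tutorial:

```shell
# Hypothetical helper: print "<class>: <number of JPEG files>" for each
# class subdirectory. retrain.py uses each subdirectory name under
# --image_dir as a label and only picks up .jpg/.jpeg files.
count_images() {
  local root="$1" d n
  for d in "$root"/*/; do
    [ -d "$d" ] || continue   # no subdirectories: the glob stays unexpanded
    # count JPEG files directly inside this class folder
    n=$(find "$d" -maxdepth 1 -type f \( -iname '*.jpg' -o -iname '*.jpeg' \) | wc -l | tr -d ' ')
    printf '%s: %s\n' "$(basename "$d")" "$n"
  done
}

count_images flower_photos
```

Run it from the directory that contains flower_photos; files sitting directly in the top-level directory are ignored because only subdirectories are listed.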
3. Download the training script
mkdir retrain
cd retrain/
curl -LO https://github.com/tensorflow/hub/raw/master/examples/image_retraining/retrain.py
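4. Get around the Great Firewall (otherwise retrain.py cannot download the pre-trained model during training); Lantern is used here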
curl -LO https://raw.githubusercontent.com/getlantern/lantern-binaries/master/lantern-installer-64-bit.deb
sudo dpkg -i lantern-installer-64-bit.deb
lantern
5. Train
python retrain.py --image_dir ../flower_photos
Note: the trained model is written to /tmp/output_graph.pb and the label file to /tmp/output_labels.txt.
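The label file is plain text with one class name per line. label_image reports each class together with its line index counting from 0, which is the number in parentheses in its output (e.g. "daisy (0)"). A small sketch for viewing that mapping; show_labels is a hypothetical helper name, and the default path assumes retrain.py's default --output_labels:

```shell
# Hypothetical helper: print "<index> <label>" for every line of the
# labels file. The index is the zero-based line number, which is what
# label_image prints in parentheses next to each class name.
show_labels() {
  awk '{ printf "%d %s\n", NR - 1, $0 }' "${1:-/tmp/output_labels.txt}"
}
```

After training, running show_labels should list the flower classes in the same order as the indices in the test output, with daisy at 0 and tulips at 4.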
6. Test
git clone https://github.com/tensorflow/tensorflow
cd tensorflow
bazel build tensorflow/examples/label_image:label_image
bazel-bin/tensorflow/examples/label_image/label_image \
--image=/test.jpg \
--graph=/tmp/output_graph.pb \
--labels=/tmp/output_labels.txt \
--input_layer="Placeholder" \
--output_layer="final_result"
Output:
daisy (0): 0.86599
sunflowers (3): 0.131518
roses (2): 0.00105852
dandelion (1): 0.000859526
tulips (4): 0.000574312
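In retrain.py, final_result is a softmax layer, so its scores over all classes add up to 1. label_image prints at most the top five classes, and since the flower data has exactly five, every class appears above: 0.86599 + 0.131518 + 0.00105852 + 0.000859526 + 0.000574312 ≈ 1.0000. A sketch of that check in awk, assuming the label_image output is piped in; sum_scores is a hypothetical helper name:

```shell
# Hypothetical helper: sum the scores in label_image output read from stdin.
# Only lines that look like "label (index): score" are counted. $NF is the
# text after the last ": ", so a log prefix before the label does not matter.
sum_scores() {
  awk -F': ' '/\([0-9]+\): / { s += $NF } END { printf "%.4f\n", s }'
}
```

For example, append 2>&1 | sum_scores to the label_image command (2>&1 in case the results are written to stderr); a value far from 1.0000 would suggest that some classes are missing from the printed list.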
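Note: to inspect the graph's parameters (such as the input and output node names used above), build and run summarize_graph from the tensorflow directory:
bazel build tensorflow/tools/graph_transforms:summarize_graph
bazel-bin/tensorflow/tools/graph_transforms/summarize_graph --in_graph=/tmp/output_graph.pb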
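To check more than one picture, the label_image call from this step can be wrapped in a loop. A sketch, assuming it is run from the tensorflow checkout after label_image has been built and that the test images are .jpg files; classify_dir and the LABEL_IMAGE variable are hypothetical names:

```shell
# Hypothetical helper: run label_image on every .jpg in a directory using
# the retrained graph and labels from step 5. LABEL_IMAGE can be
# overridden to point at a different binary.
LABEL_IMAGE=${LABEL_IMAGE:-bazel-bin/tensorflow/examples/label_image/label_image}

classify_dir() {
  local dir="$1" img
  for img in "$dir"/*.jpg; do
    [ -f "$img" ] || continue   # skip the literal pattern when nothing matches
    echo "== $img"
    # 2>&1 merges stderr into stdout so the results are captured either way
    "$LABEL_IMAGE" \
      --image="$img" \
      --graph=/tmp/output_graph.pb \
      --labels=/tmp/output_labels.txt \
      --input_layer="Placeholder" \
      --output_layer="final_result" 2>&1
  done
}
```

For example, classify_dir ~/flower_tests, where ~/flower_tests is a placeholder for your own directory of test images.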
7. Visualize the retraining with TensorBoard (optional)
tensorboard --logdir /tmp/retrain_logs
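Note: run the command above after training has started, then open localhost:6006 in a browser. /tmp/retrain_logs is the default value of retrain.py's --summaries_dir.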