一、数据准备
- animal：http://www.robots.ox.ac.uk/~vgg/data/pets/ （images.tar.gz，约 765M）
- flower：http://www.robots.ox.ac.uk/~vgg/data/flowers/ （17flowers.tgz，约 58.8M）
- plane：http://www.robots.ox.ac.uk/~vgg/data/airplanes_side/airplanes_side.tar （airplanes_side.tar，约 43.7M）
- house：http://www.robots.ox.ac.uk/~vgg/data/houses/houses.tar （houses.tar，约 16.9M）
- guitar：http://www.robots.ox.ac.uk/~vgg/data/guitars/guitars.tar （guitars.tar，约 24.5M）

下载后把每一类图片分别解压到 data/ 下以类别命名的子目录中（如 data/animal、data/flower），retrain.py 会把子目录名作为类别标签。
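二、训练模型（用 retrain.py 在预训练模型上重新训练）
首先下载 TensorFlow 源码（GitHub 地址：https://github.com/tensorflow/tensorflow ），解压到任意位置，例如 D:\TensorFlow。下面的批处理示例假设 r1.8 版本的源码解压在 F:\code\tensorflow-r1.8，请按实际路径修改。然后编写一个批处理文件，运行源码中的 tensorflow\examples\image_retraining\retrain.py 脚本，自动完成模型训练：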
REM Run TensorFlow's retrain.py example script on the images under data/,
REM writing the trained graph to output_graph.pb and the class labels to
REM output_labels.txt (both read back by the test script).
REM The script path must point at wherever the TensorFlow source was unpacked
REM (this example assumes the r1.8 source under F:\code). The exact meaning of
REM each flag is defined by retrain.py itself; run it with --help to check.
python F:\code\tensorflow-r1.8\tensorflow\examples\image_retraining\retrain.py ^
--bottleneck_dir bottleneck ^
--how_many_training_steps 200 ^
--model_dir inception_model/ ^
--output_graph output_graph.pb ^
--output_labels output_labels.txt ^
--image_dir data/
REM Wait for a key press so the console window stays open to show the output.
pause
三、测试模型
# coding: utf-8
import tensorflow as tf
import os
import numpy as np
import re
from PIL import Image
import matplotlib.pyplot as plt
# Load the class labels written by retrain.py (--output_labels).
# The code assumes line i of the file names the class at output index i of the
# 'final_result' tensor, so labels are keyed by their 0-based line number.
# The with block makes sure the file handle is closed after reading.
with tf.gfile.GFile('F:\\code\\retrain\\output_labels.txt') as labels_file:
    lines = labels_file.readlines()
# Strip the line terminator; rstrip('\r\n') also drops a stray '\r' in case
# the file was saved with Windows line endings.
uid_to_human = {uid: line.rstrip('\r\n') for uid, line in enumerate(lines)}
def id_to_string(node_id):
    """Look up the human-readable label for a class index.

    Args:
        node_id: output index of the classifier, as produced by argsort on
            the prediction vector.

    Returns:
        The label loaded from output_labels.txt, or '' when the index has
        no entry in ``uid_to_human``.
    """
    return uid_to_human.get(node_id, '')
# Load the retrained graph written by retrain.py (--output_graph) into the
# default graph. (It is the retrained model, not the original download.)
# tf.gfile.GFile replaces the deprecated FastGFile alias; the file is read and
# closed before the (potentially large) GraphDef is parsed and imported.
with tf.gfile.GFile('F:\\code\\retrain\\output_graph.pb', 'rb') as f:
    serialized_graph = f.read()
graph_def = tf.GraphDef()
graph_def.ParseFromString(serialized_graph)
# name='' imports without a name prefix, so tensors keep their exported names
# (e.g. 'final_result:0', 'DecodeJpeg/contents:0').
tf.import_graph_def(graph_def, name='')
# Filename suffixes accepted for classification. Images are fed through the
# 'DecodeJpeg/contents:0' input, so only JPEG files are passed in; anything
# else found while walking the directory (e.g. Thumbs.db) is skipped.
JPEG_EXTENSIONS = ('.jpg', '.jpeg')

with tf.Session() as sess:
    # Output tensor of the retrained classifier (presumably softmax scores,
    # judging by the name used here; one score per label in output_labels.txt).
    softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
    # Walk the directory that holds the test images (adjust to your location).
    for root, dirs, files in os.walk('F:\\code\\retrain\\images\\'):
        for file in files:
            image_path = os.path.join(root, file)
            if not file.lower().endswith(JPEG_EXTENSIONS):
                print('skipping non-JPEG file: %s' % image_path)
                continue
            # Read the raw JPEG bytes; the with block closes the handle.
            with tf.gfile.GFile(image_path, 'rb') as image_file:
                image_data = image_file.read()
            predictions = sess.run(softmax_tensor,
                                   {'DecodeJpeg/contents:0': image_data})
            # Squeeze away size-1 dimensions so predictions is 1-D.
            predictions = np.squeeze(predictions)
            # Print the image path
            print(image_path)
            # Class indices ordered from highest to lowest score
            top_k = predictions.argsort()[::-1]
            print(top_k)
            for node_id in top_k:
                human_string = id_to_string(node_id)
                score = predictions[node_id]
                print('%s (score = %.5f)' % (human_string, score))
            print()
            # Show the image; the with block closes the file PIL opened.
            with Image.open(image_path) as img:
                plt.imshow(img)
            plt.axis('off')
            plt.show()