1.下载源码包,并准备图片
将TF1.5源码包和inception_model 模型下载下来
在码云上将TF1.5的源码包下载下来
http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
图片样本参考网址–https://www.robots.ox.ac.uk/~vgg/data/
图片至少两个类别且每个种类最少200张,放在目录data/train/
2.执行retrain.bat 训练模型
现将所在盘符下的tmp文件夹中的内容删除,防止训练报错
python D:/AI/tensorflow-v1.5.0/tensorflow-v1.5.0/tensorflow/examples/image_retraining/retrain.py ^
--bottleneck_dir bottleneck ^# 图片转换为向量地址
--how_many_training_steps 2 ^ #训练的周期
--model_dir C:/Users/Administrator/Tensorflowcx/inception_model/ ^ #模型加载路径
--output_graph output_graph.pb ^ #输出训练好的模型
--output_labels output_labels.txt ^ #将训练好的标签参数输出
--image_dir data/train/ # 你要分类的图片
pause
保存为retrain.bat
3.测试训练好的模型
在jupyter Notebook里新建python文件并执行
#测试训练好的模型
import tensorflow as tf
import os
import numpy as np
import re
from PIL import Image
import matplotlib.pyplot as plt
# In[2]:
lines = tf.gfile.GFile('retrain/output_labels.txt').readlines()
uid_to_human = {}
#一行一行读取数据
for uid,line in enumerate(lines) :
#去掉换行符
line=line.strip('\n')
uid_to_human[uid] = line
def id_to_string(node_id):
if node_id not in uid_to_human:
return ''
return uid_to_human[node_id]
#创建一个图来存放google训练好的模型
with tf.gfile.FastGFile('retrain/output_graph.pb', 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
with tf.Session() as sess:
softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
#遍历目录
for root,dirs,files in os.walk('retrain/data/train/1/'):
for file in files:
#载入图片
image_data = tf.gfile.FastGFile(os.path.join(root,file), 'rb').read()
predictions = sess.run(softmax_tensor,{'DecodeJpeg/contents:0': image_data})#图片格式是jpg格式
predictions = np.squeeze(predictions)#把结果转为1维数据
#打印图片路径及名称
image_path = os.path.join(root,file)
print(image_path)
#显示图片
img=Image.open(image_path)
plt.imshow(img)
plt.axis('off')
plt.show()
#排序
top_k = predictions.argsort()[::-1]
print(top_k)
for node_id in top_k:
#获取分类名称
human_string = id_to_string(node_id)
#获取该分类的置信度
score = predictions[node_id]
print('%s (score = %.5f)' % (human_string, score))
print()