程序运行环境是tensorflow-cpu1.13.1
1.准备数据集
准备几个类别的图像数据集存放在各自类别路径下
如上图,将五个类别的图像数据分别存放在各自的文件目录,这里每个类别存放了500张.jpg图像文件,命名是0001.jpg - 0500.jpg
2.下载inception_model
这里需要将谷歌的inception_model文件放置在程序相同的路径下,可以从这里https://download.csdn.net/download/cyj5201314/16581511 下载model
3.下载retrain.py
可以从这里下载 https://download.csdn.net/download/cyj5201314/16603267
4. 指定训练参数
- 训练集的路径,我这里是存放在程序当前路径的data目录里,给该参数的default赋值路径字符串即可
- 训练结束后图模型保存路径,我这里直接保存在程序当前路径下
- 输出的标签存储路径,这里同样直接保存在当前程序路径
- 一共训练多少步, 这里指定200步
- 卷积层最终输出的张量保存路径,这里实际上只训练最后一个全连接层,卷积层全部使用谷歌训练好的参数,所以这里相当于计算全部训练图像的卷积层输出,保存到本地
- inception3模型路径,这里将指定第二步下载好的模型路径,这里直接放在程序当前路径
5训练模型
直接运行retrain.py即可
计算全部训练集的卷积层输出如下: 这里计算全部输出大概需要半个小时
训练结束后验证准确率和测试准确率如下: 训练速度很快,因为只训练最后的一个全连接层
在508个测试图像上的准确率达到100%
6.用模型识别图像
将要识别的图像放在程序当前路径下的images文件目录
运行predict.py即可
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf
import numpy as np
import re
from PIL import Image
import matplotlib.pyplot as plt
lines = tf.gfile.GFile('output_labels.txt').readlines()
uid_to_human = {}
# 一行一行读取数据
for uid,line in enumerate(lines) :
#去掉换行符
line=line.strip('\n')
uid_to_human[uid] = line
# 分类编号变成描述
def id_to_string(node_id):
if node_id not in uid_to_human:
return ''
return uid_to_human[node_id]
# 创建一个图来存放训练好的模型
with tf.gfile.FastGFile('output_graph.pb', 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
with tf.Session() as sess:
# final_result为输出tensor的名字
softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
# 遍历目录
for root,dirs,files in os.walk('images/'):
for file in files:
# 载入图片
image_data = tf.gfile.FastGFile(os.path.join(root,file), 'rb').read()
# 把图像数据传入模型获得模型输出结果
predictions = sess.run(softmax_tensor,{'DecodeJpeg/contents:0': image_data})
# 把结果转为1维数据
predictions = np.squeeze(predictions)
# 打印图片路径及名称
image_path = os.path.join(root,file)
print(image_path)
# 显示图片
img = Image.open(image_path)
plt.imshow(img)
plt.axis('off')
plt.show()
# 排序
top_k = predictions.argsort()[::-1]
for node_id in top_k:
# 获取分类名称
human_string = id_to_string(node_id)
# 获取该分类的置信度
score = predictions[node_id]
print('%s (score = %.5f)' % (human_string, score))
print()
训练和测试集数据可在 https://download.csdn.net/download/cyj5201314/16604481 下载