原文地址:https://www.tuicool.com/articles/ieQZVfa
前一帖《 TensorFlow练习10: 实现谷歌Deep Dream 》使用到了谷歌训练的Inception模型,本帖就基于Inception模型retrain一个图像分类器。
图像分类器应用广泛,连农业都在使用,如判断黄瓜种类。
本帖使用的训练数据是《 TensorFlow练习9: 生成妹子图(PixelCNN) 》一文中使用的妹子图,最后训练出的分类器可以判断图片是不是妹子图。
首先下载tensorflow源代码:
$ gitclone https://github.com/tensorflow/tensorflow
$ gitcheckoutr0.11 # checkout对应已安装的Tensorflow版本
使用examples中的image_retraining。
训练:
$ pythontensorflow/tensorflow/examples/image_retraining/retrain.py --bottleneck_dirbottleneck --how_many_training_steps 4000 --model_dirmodel --output_graphoutput_graph.pb --output_labelsoutput_labels.txt --image_dirgirl_types/
参数解释参考retrain.py源文件。
大概训练了半个小时:
生成的模型文件和labels文件:
使用训练好的模型:
importtensorflowas tf
importsys
# 命令行参数,传入要判断的图片路径
image_file = sys.argv[1]
#print(image_file)
# 读取图像
image = tf.gfile.FastGFile(image_file, 'rb').read()
# 加载图像分类标签
labels = []
for labelin tf.gfile.GFile("output_labels.txt"):
labels.append(label.rstrip())
# 加载Graph
withtf.gfile.FastGFile("output_graph.pb", 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
withtf.Session() as sess:
softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
predict = sess.run(softmax_tensor, {'DecodeJpeg/contents:0': image})
# 根据分类概率进行排序
top = predict[0].argsort()[-len(predict[0]):][::-1]
for indexin top:
human_string = labels[index]
score = predict[0][index]
print(human_string, score)
执行结果:
参考:
- https://www.tensorflow.org/versions/r0.11/how_tos/image_retraining/index.html
- TensorFlow练习4: CNN, Convolutional Neural Networks
- How Convolutional Neural Networks work
Share the post "TensorFlow练习11: 图像分类器 – retrain谷歌Inception模型"