读取 .pbtxt 文件
读取方式：tf.gfile.GFile(label_lookup_path, 'r').readlines()
imagenet_2012_challenge_label_map_proto.pbtxt 的文件内容
# Build num_2_n_string: integer target_class -> target_class_string, parsed
# from the label-map .pbtxt file.
label_lookup_path = '../inception_model/imagenet_2012_challenge_label_map_proto.pbtxt'
# Read all lines; the `with` block closes the file handle afterwards.
with tf.gfile.GFile(label_lookup_path, 'r') as label_file:
    fid = label_file.readlines()
num_2_n_string = {}
count = 0  # number of target_class_string lines recorded
for line in fid:
    # Match on the stripped line so the key test does not depend on the exact
    # leading indentation of the .pbtxt file.
    stripped = line.strip()
    if stripped.startswith('target_class:'):
        num = int(stripped.split(':', 1)[1])
    if stripped.startswith('target_class_string:'):
        # The value is a double-quoted string; strip the quotes instead of
        # eval()-ing text read from a file.
        n_string = stripped.split(':', 1)[1].strip().strip('"')
        # Pairs the string with the most recently seen target_class value.
        num_2_n_string[num] = n_string
        count += 1
加载pb模型
# Create an empty GraphDef protobuf to hold the serialized model graph.
graph_def = tf.GraphDef()
# Parse the serialized graph from the .pb file; `with` closes the file afterwards.
with tf.gfile.GFile('../inception_model/classify_image_graph_def.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())
# Import the graph into the current default graph. name='' adds no prefix,
# so tensors keep their original names (e.g. 'softmax:0').
tf.import_graph_def(graph_def, name='')
根据 .pb 文件读出模型的图结构，并测试图片
- 创建会话，并设置 TensorBoard 日志的保存路径
# Create a session and write its graph to LOGDIR so it can be viewed in TensorBoard.
sess = tf.Session()
LOGDIR = './logs/'
train_writer = tf.summary.FileWriter(LOGDIR)
train_writer.add_graph(sess.graph)
# FileWriter writes events asynchronously; flush so the graph is on disk now
# rather than only after the writer's periodic flush.
train_writer.flush()
- 取出softmax层
# Look up the output probability tensor of the imported graph by name.
softmax = sess.graph.get_tensor_by_name(name='softmax:0')
- 读取要测试的图像及预测
# Read the raw JPEG bytes; `with` closes the file handle afterwards.
with tf.gfile.GFile('../images/car.jpg', 'rb') as image_file:
    image_data = image_file.read()
# Feed the encoded bytes to the graph's 'DecodeJpeg/contents:0' input and
# fetch the softmax output.
predict = sess.run(softmax, feed_dict={'DecodeJpeg/contents:0': image_data})
输出测试结果
# Class index with the highest probability. argmax is taken directly on the
# fetched NumPy array (axis 1) instead of adding a tf.argmax op to the graph;
# element [0] assumes a batch of one image — TODO confirm for other inputs.
num = int(predict.argmax(axis=1)[0])
# Report the class that was actually predicted rather than a hard-coded id.
# NOTE(review): num_2_description is only built in the complete code listing
# (synset -> human label join); it must exist before this cell runs.
print(num, num_2_description[num])
完整代码
import tensorflow as tf
import numpy as np
# Build num_2_n_string: integer target_class -> target_class_string, parsed
# from the label-map .pbtxt file.
label_lookup_path = '../inception_model/imagenet_2012_challenge_label_map_proto.pbtxt'
# Read all lines; the `with` block closes the file handle afterwards.
with tf.gfile.GFile(label_lookup_path, 'r') as label_file:
    fid = label_file.readlines()
num_2_n_string = {}
count = 0  # number of target_class_string lines recorded
for line in fid:
    # Match on the stripped line so the key test does not depend on the exact
    # leading indentation of the .pbtxt file.
    stripped = line.strip()
    if stripped.startswith('target_class:'):
        num = int(stripped.split(':', 1)[1])
    if stripped.startswith('target_class_string:'):
        # The value is a double-quoted string; strip the quotes instead of
        # eval()-ing text read from a file.
        n_string = stripped.split(':', 1)[1].strip().strip('"')
        # Pairs the string with the most recently seen target_class value.
        num_2_n_string[num] = n_string
        count += 1
# Sanity check: one sample entry, lines recorded, and distinct classes.
print(num_2_n_string[396], count, len(num_2_n_string))
# Build n_string_2_description from a tab-separated file: first column is the
# n_string key, the remainder of the line is its description.
n_string_description_path = '../inception_model/imagenet_synset_to_human_label_map.txt'
n_string_2_description = {}
# `with` closes the file; explicit encoding avoids depending on the OS locale.
with open(n_string_description_path, 'r', encoding='utf-8') as fo:
    for line in fo:
        line = line.strip()
        if line:  # skip blank lines
            # maxsplit=1 keeps any further tabs inside the description.
            n_string, description = line.split('\t', 1)
            n_string_2_description[n_string] = description
# Spot-check one key (a bare expression: it only displays in a notebook/REPL,
# but raises KeyError if the key is missing).
n_string_2_description['n00004475']
# Join the two maps: target_class -> n_string -> description. Classes whose
# n_string has no entry in n_string_2_description are left out.
num_2_description = {
    num: n_string_2_description[n_string]
    for num, n_string in num_2_n_string.items()
    if n_string in n_string_2_description
}
# Compare sizes to see how many classes found a description.
print(len(num_2_description), len(n_string_2_description), len(num_2_n_string))
# Create an empty GraphDef protobuf to hold the serialized model graph.
graph_def = tf.GraphDef()
# Parse the serialized graph from the .pb file; `with` closes the file afterwards.
with tf.gfile.GFile('../inception_model/classify_image_graph_def.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())
# Import the graph into the current default graph. name='' adds no prefix,
# so tensors keep their original names (e.g. 'softmax:0').
tf.import_graph_def(graph_def, name='')
# Write the graph structure loaded from the .pb file to LOGDIR (viewable in TensorBoard).
sess = tf.Session()
LOGDIR = './logs/'
train_writer = tf.summary.FileWriter(LOGDIR)
train_writer.add_graph(sess.graph)
# FileWriter writes events asynchronously; flush so the graph is on disk now
# rather than only after the writer's periodic flush.
train_writer.flush()
# Output probability tensor of the imported graph.
softmax = sess.graph.get_tensor_by_name('softmax:0')
# Read the raw JPEG bytes; `with` closes the file handle afterwards.
with tf.gfile.GFile('../images/car.jpg', 'rb') as image_file:
    image_data = image_file.read()
# Feed the encoded bytes to the graph's decode input and fetch the softmax output.
predict = sess.run(softmax, feed_dict={'DecodeJpeg/contents:0': image_data})
# Class with the highest probability, computed on the fetched array instead of
# adding a tf.argmax op to the graph; [0] assumes a batch of one image.
num = int(np.argmax(predict, axis=1)[0])
# Report the class that was actually predicted rather than a hard-coded id.
print(num, num_2_description[num])