TensorFlow学习笔记(七) 使用训练好的inception_v3模型预测分类图片

下载需要练习的inception模型并看起流程

  1. import tensorflow as tf  
  2. import os  
  3. import tarfile  
  4. import requests  
  5.   
  6. #inception_v3模型下载  
  7. inception_pretrain_model_url = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'  
  8.   
  9. # 模型存放地址  
  10. inception_pretrain_model_dir = "inception_model"  
  11. if not os.path.exists(inception_pretrain_model_dir):  
  12.     os.makedirs(inception_pretrain_model_dir)  
  13.   
  14. #获取文件名,以及文件路径  
  15. filename = inception_pretrain_model_url.split('/')[-1]  
  16. filepath = os.path.join(inception_pretrain_model_dir, filename)  
  17.   
  18. #下载模型  
  19. if not os.path.exists(filepath):  
  20.     print('download: ', filename)  
  21.     r = requests.get(inception_pretrain_model_url, stream=True)  
  22.     with open(filepath,'wb') as f:  
  23.         for chunk in r.iter_content(chunk_size=1024):  
  24.             if chunk:  
  25.                 f.write(chunk)  
  26. print("finishn: ", filename)  
  27.   
  28. #解压文件  
  29. tarfile.open(filepath, 'r:gz').extractall(inception_pretrain_model_dir)  
  30.   
  31. #模型结构存放文件  
  32. log_dir = 'inception_log'  
  33. if not os.path.exists(log_dir):  
  34.     os.makedirs(log_dir)  
  35.   
  36. #classify_image_graph_def.pb为google训练好的模型  
  37. inception_graph_def_file = os.path.join(inception_pretrain_model_dir, 'classify_image_graph_def.pb')  
  38. with tf.Session() as sess:  
  39.     #创建一个图来存放google训练好的模型,load graph 具体实现方法看下面的链接  
  40.     with tf.gfile.FastGFile(inception_graph_def_file, 'rb') as f:  
  41.         graph_def = tf.GraphDef()  
  42.         graph_def.ParseFromString(f.read())  
  43.         tf.import_graph_def(graph_def, name='')  
  44.   
  45.     #保存图的结构  
  46.     writer = tf.summary.FileWriter(log_dir, sess.graph)  
  47.     writer.close() 
其结构图孺如下




使用inception模型检测图片

里面主要写了labels排序等 的实现,以及利用训练好的模型识别图片的实现


  1. import tensorflow as tf  
  2. import os  
  3. import numpy as np  
  4. import re  
  5. from PIL import Image  
  6. import matplotlib.pyplot as plt  
  7.   
  8. class NodeLookup(object):  
  9.     def __init__(self):  
  10.         label_lookup_path = 'inception_model/imagenet_2012_challenge_label_map_proto.pbtxt'  
  11.         uid_lookup_path = 'inception_model/imagenet_synset_to_human_label_map.txt'  
  12.         self.node_lookup = self.load(label_lookup_path, uid_lookup_path)  
  13.   
  14.     def load(self, label_lookup_path, uid_lookup_path):  
  15.         #加载分类字符串n ------ 对应分类名称的文件  
  16.         proto_as_ascii_lines = tf.gfile.GFile(uid_lookup_path).readlines()  
  17.         uid_to_human = {}  
  18.         #一行一行读取数据  
  19.         for line in proto_as_ascii_lines :  
  20.             #去掉换行符  
  21.             line = line.strip('\n')  
  22.             #按照‘\t’分割  
  23.             parsed_items = line.split('\t')  
  24.             #获取分类编号和分类名称  
  25.             uid = parsed_items[0]  
  26.             human_string = parsed_items[1]  
  27.             #保存编号字符串-----与分类名称映射关系  
  28.             uid_to_human[uid] = human_string  
  29.   
  30.   
  31.         #加载分类字符串n ----- 对应分类编号1-1000的文件  
  32.         proto_as_ascii_lines = tf.gfile.GFile(label_lookup_path).readlines()  
  33.         node_id_to_uid = {}  
  34.         for line in proto_as_ascii_lines :  
  35.             if line.startswith('  target_class:'):  
  36.                 #获取分类编号1-1000  
  37.                 target_class = int(line.split(': ')[1])  
  38.             if line.startswith('  target_class_string:'):  
  39.                 #获取编号字符串n****  
  40.                 target_class_string = line.split(': ')[1]  
  41.                 #保存分类编号1-1000与编号字符串n****的映射关系  
  42.                 node_id_to_uid[target_class] = target_class_string[1:-2]  
  43.   
  44.   
  45.         #建立分类编号1-1000对应分类名称的映射关系  
  46.         node_id_to_name = {}  
  47.         for key, val in node_id_to_uid.items():  
  48.             #获取分类名称  
  49.             name = uid_to_human[val]  
  50.             #建立分类编号1-1000到分类名称的映射关系  
  51.             node_id_to_name[key] = name  
  52.         return node_id_to_name  
  53.   
  54.     #传入分类编号1-1000返回分类名称  
  55.     def id_to_string(self, node_id):  
  56.         if node_id not in self.node_lookup:  
  57.             return ''  
  58.         return self.node_lookup[node_id]  
  59.   
  60. #创建一个图来存放google训练好的模型  #2 load graph  
  61. with tf.gfile.FastGFile('inception_model/classify_image_graph_def.pb', 'rb') as f:  
  62.     graph_def = tf.GraphDef()  
  63.     graph_def.ParseFromString(f.read())  
  64.     tf.import_graph_def(graph_def, name='')  
  65.   
  66.   
  67. with tf.Session() as sess:  
  68.     softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')  
  69.     #遍历目录  
  70.     for root, dirs, files in os.walk('images/'):  
  71.         for file in files:  
  72.             #载入图片  
  73.             image_data = tf.gfile.FastGFile(os.path.join(root,file), 'rb').read()  
  74.             predictions = sess.run(softmax_tensor, {'DecodeJpeg/contents:0': image_data})#图片格式是jpg格式  
  75.             predictions = np.squeeze(predictions)#把结果转为1维  
  76.   
  77.             #打印图片路径及名称  
  78.             image_path = os.path.join(root,file)  
  79.             print(image_path)  
  80.             #显示图片  
  81.             img = Image.open(image_path)  
  82.             plt.imshow(img)  
  83.             plt.axis('off')  
  84.             plt.show()  
  85.   
  86.             #排序  
  87.             top_k = predictions.argsort()[-5:][::-1]  
  88.             node_lookup = NodeLookup()  
  89.             for node_id in top_k:  
  90.                 #获取分类名称  
  91.                 human_string = node_lookup.id_to_string(node_id)  
  92.                 #获取该分类的置信度  
  93.                 score = predictions[node_id]  
  94.                 print('%s (score = %.5f)' % (human_string, score))  
  95.             print()  




识别的准确率还是可以的,如果遇到不认识的狗子,识别一下就好啦,哈哈


  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值