使用Inception v3实现图像识别

官方相关模型及文件下载:http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
下载的文件列表:
在这里插入图片描述
里面有训练好的pb模型,以及标签特征对应文件
在这里插入图片描述

载入相关库:

import tensorflow as tf
import os
import numpy as np
import re
from PIL import Image
import matplotlib.pyplot as plt

code:

class NodeLookup(object):
    def __init__(self):
    #加载对应的分类文件
        label_path = "inception_model/imagenet_2012_challenge_label_map_proto.pbtxt"
        uid_path = "inception_model/imagenet_synset_to_human_label_map.txt"
        self.node_lookup = self.load(label_path, uid_path)
    
    def load(self, label_path, uid_path):
    #读取pbtxt文件,按行读取
        proto_label_line = tf.gfile.GFile(label_path).readlines()
        #将读取的内容以字典的方式存储,key-value
        node_to_label_uid = {}
        #遍历每一行
        for line in proto_label_line:
        #去掉换行符
            line = line.strip('\n')
            #如果每一行以target_class开头:则对应的内容为  target_class: 442
            if line.startswith("  target_class:"):
                target_class = int(line.split(": ")[1])#442
            if line.startswith("  target_class_string:"):
            #包含引号
                target_class_string = line.split(": ")[1]
                #  target_class_string: "n01494475",取中间的字母数字部分,不包含引号
                node_to_label_uid[target_class] = target_class_string[1:-1]
        
        proto_uid_line = tf.gfile.GFile(uid_path).readlines()
        uid_to_label_human = {}
        for line in proto_uid_line:
            line = line.strip('\n')
            #分割
            parse_items = line.split("\t")
            uid = parse_items[0]
            target = parse_items[1]
            uid_to_label_human[uid] = target
            #uid_to_label_human[parse_items[0]] = parse_items[1]
            
        #重新对应如种类编号442对应类别cat(假设)
        node_id_to_name = {}
        for key, val in node_to_label_uid.items():
            name = uid_to_label_human[val]
            node_id_to_name[key] = name
        return node_id_to_name
    
    def id_to_string(self, node_id):
        if node_id not in self.node_lookup:
            return ''
        return self.node_lookup[node_id]

#读取模型,重新创建相应的图,固定书写模式
with tf.gfile.GFile("inception_model/classify_image_graph_def.pb",'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name="")
    
with tf.Session() as sess:
    softmax_tensor = sess.graph.get_tensor_by_name("softmax:0")
    for root, dirs, files in os.walk("images/"):
        for file in files:
            image_data  = tf.gfile.GFile(os.path.join(root, file), 'rb').read()
            predict = sess.run(softmax_tensor, {"DecodeJpeg/contents:0":image_data})#图片是jpg格式
            predict = np.squeeze(predict)#将结果转为一维数据
            print(predict.shape)
            
            image_path = os.path.join(root, file)
            print(image_path)
            img = Image.open(image_path)
            plt.imshow(img)
            plt.axis("off")
            plt.show()
            
            top_pre = predict.argsort()[-3:][::-1]#升序排列,选择最后三个数据,再倒叙,即降序
            node_lookup = NodeLookup()
            for node_id in top_pre:
                human_string = node_lookup.id_to_string(node_id)
                score = predict[node_id]
                print("%s (score=%.5f)" %(human_string, score))
            print("*************************")   


结果:
在这里插入图片描述

bug

1、readline()&readlines(),这里使用readlines(),否则只读取一行数据

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-21-74081abf60d9> in <module>
     74             #对预测结果排序,从小到大排序,取最后几个,[::-1]:对最后的五个概率值取倒叙
     75             top_k = predictions.argsort()[-5:][: :-1]
---> 76             node_lookup = NodeLookUp()
     77             for node_id in top_k:
     78                 #获取分类名称

<ipython-input-21-74081abf60d9> in __init__(self)
      4         label_lookup_path = "inception_model/imagenet_2012_challenge_label_map_proto.pbtxt"
      5         uid_lookup_path = "inception_model/imagenet_synset_to_human_label_map.txt"
----> 6         self.node_lookup = self.load(label_lookup_path, uid_lookup_path)
      7 
      8     #载入数据进行处理

<ipython-input-21-74081abf60d9> in load(self, label_lookup_path, uid_lookup_path)
     19             #用键值对存储数据
     20             uid = parsed_items[0]
---> 21             human_string = parsed_items[1]
     22             #保存id和字符串为对应的映射关系
     23             uid_to_human[uid] = human_string

IndexError: list index out of range


2、载入模型时没有对其采用二进制读取,‘rb’

---------------------------------------------------------------------------
UnicodeDecodeError                        Traceback (most recent call last)
<ipython-input-27-052cb1c2f17f> in <module>
     40 with tf.gfile.GFile("inception_model/classify_image_graph_def.pb") as f:
     41     graph_def = tf.GraphDef()
---> 42     graph_def.ParseFromString(f.read())
     43     tf.import_graph_def(graph_def, name="")
    UnicodeDecodeError: 'utf-8' codec can't decode byte 0xbb in position 1: invalid start byte

‘r’:默认值,表示从文件读取数据。
‘w’:表示要向文件写入数据,并截断以前的内容
‘a’:表示要向文件写入数据,添加到当前内容尾部
‘r+’:表示对文件进行可读写操作(删除以前的所有数据)
‘r+a’:表示对文件可进行读写操作(添加到当前文件尾部)
‘b’:表示要读写二进制数据

3、未定义,tab,代码格式问题

if line.startswith("  target_class_string:"):
                target_class_string = line.split(": ")[1]
node_to_label_uid[target_class] = target_class_string[1:-1]

<ipython-input-28-ecfdf82b86d1> in load(self, label_path, uid_path)
     14             if line.startswith("  target_class_string:"):
     15                 target_class_string = line.split(": ")[1]
---> 16             node_to_label_uid[target_class] = target_class_string
     17 
     18         proto_uid_line = tf.gfile.GFile(uid_path).readlines()

UnboundLocalError: local variable 'target_class_string' referenced before assignment
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值