.tflite在PC端的检测程序,python实现

# -*- coding:utf-8 -*-
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "-1" 
import numpy as np
import time
from keras.preprocessing import image
import tensorflow as tf
from keras.applications.imagenet_utils import  preprocess_input

#test_image_dir = '../data/trainingData/pos'
# msoldierDetection/models/retrained_graph_MobileNet.tflite
model_path = "../models/retrained_graph_MobileNet.tflite"
#model_path = "../models/retrained_graph_RedHeadFile_and_Soldier.tflite"
# Load TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path=model_path)
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
print(str(input_details))
output_details = interpreter.get_output_details()
print(str(output_details))

def load_labels(label_file):
  label = []
  proto_as_ascii_lines = tf.gfile.GFile(label_file).readlines()
  for l in proto_as_ascii_lines:
    label.append(l.rstrip())
  return label

def file_name(file_dir):
    for root, dirs, files in os.walk(file_dir):
        # dirs  # 当前路径下所有子目录
        return root, files  # 当前路径下所有非目录子文件

 
all_imgs = file_name('..\\data\\trainingData\\pos')
#all_imgs = file_name('D:\Imagenet2012\ILSVRC2012_img_test')
#all_imgs = file_name('../data/wrong/other')
#all_imgs = file_name('D:\\working\\Python\\readHeadFileTraining\\Testflower\\tf_files\\flower_photos\\redheadfile')
# with tf.Session( ) as sess:
if model_path.find("Soldier") != -1:
    label_file = "../models/retrained_labels.txt"
else:
    label_file = "../models/retrained_labels_1000.txt"
labels = load_labels(label_file)

if 1:
    findnum=0
    model_interpreter_time = 0
    start_time = time.time()
    # 遍历文件

    for idx, img_name in enumerate(all_imgs[1]):
    #for file in file_list:
        full_path = all_imgs[0] + '/' + img_name




        # # 只要黑白的,大小控制在(28,28)
        # img = cv2.imread(full_path)
        # res_img = cv2.resize(img, (224, 224))
        # # 变成长784的一维数据
        # new_img = res_img.reshape(224,224,3)
        #
        # # 增加一个维度,变为 [1, 784]
        # image_np_expanded = np.expand_dims(new_img, axis=0)
        # image_np_expanded = image_np_expanded.astype('float32')  # 类型也要满足要求
        #full_path = "../data/wrong/h.jpg"
        print(str(idx)+' full_path:{}'.format(full_path))
        img = image.load_img(full_path, target_size=(224, 224))
        x = image.img_to_array(img)
        x = np.expand_dims(x, axis=0)
        x = preprocess_input(x,mode='tf')
        #x = x.astype('float32')  # 类型也要满足要求

        # 填装数据
        model_interpreter_start_time = time.time()
        interpreter.set_tensor(input_details[0]['index'], x)

        # 注意注意,我要调用模型了
        interpreter.invoke()
        output_data = interpreter.get_tensor(output_details[0]['index'])
        model_interpreter_time += time.time() - model_interpreter_start_time

        # 出来的结果去掉没用的维度
        results = np.squeeze(output_data)
        #print('result:{}'.format(result))
        # print('result:{}'.format(sess.run(output, feed_dict={newInput_X: image_np_expanded})))
        #print(partition_arg_topK(result, 3 ))
        top_k = results.argsort()[-5:][::-1]
        template = "{} (score={:0.5f})"
        findKind = -1
        objinfor=''
        for i in top_k:
            #print(template.format(labels[i], results[i])) 注: or or labels[i].find('projectile') != -1 \ labels[i].find('pickelhaube') != -1 \groom, bridegroom,'kimono',labels[i].find('cuirass') != -1 or和服,有可能oror labels[i].find('rifle')!=-1  labels[i].find('vestment') != -1labels[i].find('shield') != -1  oror labels[i].find('helmet')!=-1 \or labels[i].find('prison') != -1  \
            if labels[i].find('military uniform')!=-1 \
                    or labels[i].find('gun')!=-1 \
                    or labels[i].find('bulletproof')!=-1 \
                    or labels[i].find('army tank') != -1 \
                    or labels[i].find('cannon')!=-1 :
                if results[i]>0.03:
                    findKind = 1
                    objinfor=objinfor+labels[i]+('%.3f' % results[i])


            # if labels[i].find('website')!=-1 or labels[i].find('menu')!=-1 \
            #         or labels[i].find('screen')!=-1 or labels[i].find('bulletproof')!=-1 \
            #         or labels[i].find('envelope') != -1 \
            #         :
            #     findKind = 2

        objinfor=objinfor.replace("'","")
        objinfor = objinfor.replace(":", " ")
        objinfor = objinfor.replace("\n", "")
        if findKind is -1:
            image.save_img('../data/result/neg/wrong_val_' + str(idx) + "_" + str(results[2]) + '.jpg', img)
            # for i in top_k:
            #     print(template.format(labels[i], results[i]))
        else:
            try:
                findnum = findnum + 1
                image.save_img('../data/result/pos/right_val_Kind'+"_"+ str(idx) + "_" + str(findKind)+"_"+objinfor+ str(results[2]) + '.jpg', img)
                for i in top_k:
                     print(template.format(labels[i], results[i]))
            except:
                print(objinfor)

        # 输出结果是长度为10(对应0-9)的一维数据,最大值的下标就是预测的数字
        #print('result:{}'.format((np.where(result == np.max(result)))[0][0]))
        # if (result[2] > 0.5):
        #     print ("红头文件")
        # if (result[0] > 0.005):
        #     print ("军服")


    used_time = time.time() - start_time
    print('共记发现:{}个违规'.format(findnum))
    print('used_time:{}'.format(used_time))
    print('model_interpreter_time:{}'.format(model_interpreter_time))
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值