# -*- coding:utf-8 -*-
"""Classify a directory of images with a TFLite model on a PC (CPU only).

For every top-level file in ``image_dir`` the script

1. resizes the image to the model's input size and runs the TFLite
   interpreter on it,
2. inspects the ``TOP_K`` highest-scoring labels and counts the image as a
   hit when one of them contains a keyword from ``TARGET_KEYWORDS`` with a
   score above ``SCORE_THRESHOLD``,
3. saves the image to ``POS_DIR`` (hit; a label/score summary is embedded in
   the file name) or to ``NEG_DIR`` (no hit).

At the end it prints the number of hits, the total run time and the time
spent inside the interpreter.
"""
import os
import re
import time

# Hide every GPU so TensorFlow runs on the CPU. This must happen before
# TensorFlow is imported, which is why the import section is split here.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import numpy as np
import tensorflow as tf
from keras.applications.imagenet_utils import preprocess_input
from keras.preprocessing import image

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# msoldierDetection/models/retrained_graph_MobileNet.tflite
model_path = "../models/retrained_graph_MobileNet.tflite"
# model_path = "../models/retrained_graph_RedHeadFile_and_Soldier.tflite"

# Directory whose top-level files are classified. Built with os.path.join
# (instead of a '..\\data\\...' literal) so the script also runs off Windows.
image_dir = os.path.join('..', 'data', 'trainingData', 'pos')
# image_dir = '../data/wrong/other'

POS_DIR = '../data/result/pos'   # images with at least one hit
NEG_DIR = '../data/result/neg'   # images without a hit

# A top-K label is a hit when it contains one of these substrings ...
TARGET_KEYWORDS = ('military uniform', 'gun', 'bulletproof', 'army tank',
                   'cannon')
# ... and its score is strictly greater than this threshold.
SCORE_THRESHOLD = 0.03
# Number of highest-scoring classes inspected per image.
TOP_K = 5

# Cap on the label/score summary embedded in a file name, so the complete
# name stays below the common 255-character file-name limit.
MAX_INFO_LEN = 150
# Characters Windows forbids in file names, plus ASCII control characters.
_UNSAFE_FILENAME_CHARS = re.compile(r'[<>"/\\|?*\x00-\x1f]')

template = "{} (score={:0.5f})"

# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------
# Load the TFLite model and allocate its tensors.
interpreter = tf.lite.Interpreter(model_path=model_path)
interpreter.allocate_tensors()

# Descriptions of the input/output tensors (index, shape, dtype, ...).
input_details = interpreter.get_input_details()
print(str(input_details))
output_details = interpreter.get_output_details()
print(str(output_details))

# Resize target taken from the model instead of a hard-coded 224x224. The
# input is fed as a (1, height, width, channels) batch, so axes 1 and 2 are
# height and width.
input_height = int(input_details[0]['shape'][1])
input_width = int(input_details[0]['shape'][2])


def load_labels(label_file):
    """Return the labels listed one per line in ``label_file``.

    Trailing whitespace (including the newline) is stripped from every line.
    Line ``i`` is used as the label of model output index ``i``.
    """
    # Plain open() instead of the TF1-only tf.gfile.GFile; the `with` block
    # also closes the file, which the original never did.
    with open(label_file, encoding='utf-8') as f:
        return [line.rstrip() for line in f]


def file_name(file_dir):
    """Return ``(root, files)`` for the top level of ``file_dir``.

    ``root`` is ``file_dir`` itself and ``files`` holds the names of the
    non-directory entries directly inside it; sub-directories are not
    descended into.

    Raises:
        FileNotFoundError: if ``file_dir`` is not an existing, readable
            directory (``os.walk`` then yields nothing, and the previous
            version silently returned ``None``).
    """
    for root, _dirs, files in os.walk(file_dir):
        return root, files  # only the first (top-level) entry is wanted
    raise FileNotFoundError(
        'Image directory not found or unreadable: {}'.format(file_dir))


def _is_target_label(label):
    """Return True if ``label`` contains any keyword in TARGET_KEYWORDS."""
    return any(keyword in label for keyword in TARGET_KEYWORDS)


def _to_safe_filename_part(text, max_len=MAX_INFO_LEN):
    """Make ``text`` safe to embed in a file name.

    Applies the original rules (drop ``'``, turn ``:`` into a space, drop
    newlines), replaces every other character Windows forbids in file names
    with ``_``, and truncates the result to ``max_len`` characters.
    """
    text = text.replace("'", "").replace(":", " ").replace("\n", "")
    text = _UNSAFE_FILENAME_CHARS.sub('_', text)
    return text[:max_len]


def _predict(img):
    """Run the interpreter on one loaded image.

    Returns:
        ``(scores, seconds)``: the first output tensor with size-1 axes
        squeezed away, and the time spent in set_tensor/invoke/get_tensor.
    """
    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)        # add the batch axis
    x = preprocess_input(x, mode='tf')   # Keras 'tf' mode: scale to [-1, 1]
    start = time.perf_counter()
    interpreter.set_tensor(input_details[0]['index'], x)
    interpreter.invoke()
    output_data = interpreter.get_tensor(output_details[0]['index'])
    elapsed = time.perf_counter() - start
    return np.squeeze(output_data), elapsed


def _save(path, img):
    """Save ``img`` to ``path``.

    Returns True on success. On failure the error is printed and False is
    returned, so one bad write does not abort the run.
    """
    try:
        image.save_img(path, img)
    except (OSError, ValueError) as err:
        print('failed to save {}: {}'.format(path, err))
        return False
    return True


# ---------------------------------------------------------------------------
# Main loop
# ---------------------------------------------------------------------------
if "Soldier" in model_path:
    label_file = "../models/retrained_labels.txt"
else:
    label_file = "../models/retrained_labels_1000.txt"
labels = load_labels(label_file)

num_classes = int(output_details[0]['shape'][-1])
if len(labels) < num_classes:
    # labels[i] below would raise IndexError for classes without a label.
    print('warning: {} lists {} labels but the model has {} outputs'.format(
        label_file, len(labels), num_classes))

all_imgs = file_name(image_dir)

# Create the output folders up front; saving into a missing folder used to
# crash the negative branch and was silently swallowed in the positive one.
for out_dir in (POS_DIR, NEG_DIR):
    os.makedirs(out_dir, exist_ok=True)

findnum = 0                 # number of images with at least one hit
model_interpreter_time = 0  # seconds spent inside the interpreter
start_time = time.perf_counter()

img_root, img_names = all_imgs
for idx, img_name in enumerate(img_names):
    full_path = os.path.join(img_root, img_name)
    print('{} full_path:{}'.format(idx, full_path))
    try:
        img = image.load_img(full_path,
                             target_size=(input_height, input_width))
    except (OSError, ValueError) as err:
        # A non-image file in the directory used to abort the whole run.
        print('skipping {}: {}'.format(full_path, err))
        continue

    results, elapsed = _predict(img)
    model_interpreter_time += elapsed

    # Indices of the TOP_K highest scores, best first.
    top_k = results.argsort()[-TOP_K:][::-1]
    hits = [i for i in top_k
            if _is_target_label(labels[i]) and results[i] > SCORE_THRESHOLD]
    find_kind = 1 if hits else -1
    obj_info = _to_safe_filename_part(
        ''.join(labels[i] + ('%.3f' % results[i]) for i in hits))

    # NOTE(review): results[2] is the score of class index 2 (presumably the
    # "red-head document" class of the Soldier model, per the author's old
    # notes) - confirm it is meaningful for the 1000-class model.
    if find_kind == -1:
        _save(os.path.join(NEG_DIR, 'wrong_val_' + str(idx) + "_"
                           + str(results[2]) + '.jpg'), img)
    else:
        findnum += 1
        saved = _save(os.path.join(
            POS_DIR, 'right_val_Kind' + "_" + str(idx) + "_" + str(find_kind)
            + "_" + obj_info + str(results[2]) + '.jpg'), img)
        if saved:
            for i in top_k:
                print(template.format(labels[i], results[i]))

used_time = time.perf_counter() - start_time
print('共计发现:{}个违规'.format(findnum))
print('used_time:{}'.format(used_time))
print('model_interpreter_time:{}'.format(model_interpreter_time))
# .tflite在PC端的检测程序，Python实现
# 最新推荐文章于 2024-04-22 12:32:28 发布