其实，想要使用别人训练好的模型很简单：确定模型的输入、输出张量名，然后运行模型即可。下面的示例加载一个冻结的 TensorFlow 检测模型（.pb），对文件夹中的 .jpg 图片做检测，并把检测框画到图上保存：
import numpy as np
import tensorflow as tf
import cv2 as cv
import os
# Class names indexed by (class id - 1); ids follow the COCO label ids
# (presumably, given the "coco14" model) and only ids 1..8 are listed.
_CLASS_NAMES = ["person", "bicycle", "car", "motorbike", "aeroplane", "bus", "train", "truck"]
# Box/label color per class id, as (B, G, R) tuples for OpenCV drawing on the
# BGR image. Detections of any other class are kept but not drawn.
_CLASS_COLORS = {
    1: (255, 255, 0),
    3: (0, 0, 255),
    6: (0, 255, 255),
    8: (255, 0, 255),
}
# Class ids affected by the optional ``min_vehicle_size`` filter in main().
_VEHICLE_CLASS_IDS = (3, 6, 8)
# Output tensors fetched per image, in the order _detect() unpacks them.
_OUTPUT_TENSORS = ('num_detections:0', 'detection_scores:0',
                   'detection_boxes:0', 'detection_classes:0')
_DEFAULT_MODEL_PATH = (
    r'D:\share\ssd_resnet50_v1_fpn_shared_box_predictor_640x640_coco14_sync_2018_07_03'
    r'\frozen_inference_graph.pb'
)


def _load_frozen_graph(model_path):
    """Parse a frozen ``GraphDef`` (.pb) file and import it into a new ``tf.Graph``."""
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(model_path, 'rb') as f:
        graph_def.ParseFromString(f.read())
    graph = tf.Graph()
    with graph.as_default():
        # name='' keeps the original node names so tensors stay addressable
        # as e.g. 'detection_boxes:0'.
        tf.import_graph_def(graph_def, name='')
    return graph


def _detect(sess, fetches, rgb_img, score_threshold):
    """Run the detector on one RGB image and return the detections kept.

    Returns a list of ``(class_id, (x, y, right, bottom))`` in pixel
    coordinates of ``rgb_img``; detections scoring below ``score_threshold``
    are dropped.
    """
    num, scores, boxes, classes = sess.run(
        fetches, feed_dict={'image_tensor:0': np.expand_dims(rgb_img, axis=0)})
    # Boxes are normalized [ymin, xmin, ymax, xmax] relative to the fed image:
    # x must scale by the width (shape[1]) and y by the height (shape[0]).
    height, width = rgb_img.shape[:2]
    detections = []
    for i in range(int(num[0])):
        if float(scores[0][i]) < score_threshold:
            continue
        ymin, xmin, ymax, xmax = (float(v) for v in boxes[0][i])
        detections.append((int(classes[0][i]),
                           (xmin * width, ymin * height, xmax * width, ymax * height)))
    return detections


def _drop_small_vehicles(detections, min_size):
    """Remove vehicle detections whose pixel width or height is below ``min_size``."""
    return [
        (class_id, box) for class_id, box in detections
        if not (class_id in _VEHICLE_CLASS_IDS
                and (box[2] - box[0] < min_size or box[3] - box[1] < min_size))
    ]


def _draw_detections(img, detections):
    """Draw a box and class label on ``img`` (in place) for each colored class."""
    font = cv.FONT_HERSHEY_SIMPLEX
    for class_id, (x, y, right, bottom) in detections:
        color = _CLASS_COLORS.get(class_id)
        if color is None:
            continue  # classes without an assigned color are not drawn
        top_left = (int(x), int(y))
        cv.rectangle(img, top_left, (int(right), int(bottom)), color, thickness=2)
        cv.putText(img, _CLASS_NAMES[class_id - 1], top_left, font, 0.8, color, 2, False)


def main(folder_path=r'D:\share\samples',
         result_path=r'D:\share\test_result',
         model_path=_DEFAULT_MODEL_PATH,
         score_threshold=0.5,
         input_size=640,
         min_vehicle_size=None):
    """Run the frozen detector on every ``.jpg`` in ``folder_path`` and save annotated copies.

    Each image is padded by ``pad_to_square`` to ``input_size`` x ``input_size``,
    converted BGR -> RGB and fed to the graph. Detections scoring at least
    ``score_threshold`` are drawn on the padded (BGR) image, which is written
    to ``result_path`` under the same file name.

    Args:
        folder_path: directory scanned (non-recursively) for ``*.jpg`` files.
        result_path: output directory; created together with missing parents.
        model_path: path to the frozen ``GraphDef`` (.pb) file.
        score_threshold: detections with a lower score are discarded.
        input_size: side length passed to ``pad_to_square``.
        min_vehicle_size: if not None, drop detections of classes 3/6/8 whose
            pixel width or height is smaller than this value.
    """
    os.makedirs(result_path, exist_ok=True)
    graph = _load_frozen_graph(model_path)
    # Resolve the output tensors once instead of on every image.
    fetches = [graph.get_tensor_by_name(name) for name in _OUTPUT_TENSORS]
    with tf.Session(graph=graph) as sess:
        for sub in os.listdir(folder_path):
            if not sub.endswith('.jpg'):
                continue
            img_name = os.path.join(folder_path, sub)
            img = cv.imread(img_name)
            if img is None:
                # cv.imread signals an unreadable file by returning None.
                print('Skipping unreadable image:', img_name)
                continue
            # NOTE(review): pad_to_square is not defined in this file; it is
            # assumed to return a BGR image of size input_size x input_size.
            pad_img = pad_to_square(img, [input_size, input_size])
            # The model is fed this RGB copy; boxes are drawn on the BGR original.
            rgb_img = pad_img[:, :, [2, 1, 0]]
            detections = _detect(sess, fetches, rgb_img, score_threshold)
            if min_vehicle_size is not None:
                detections = _drop_small_vehicles(detections, min_vehicle_size)
            _draw_detections(pad_img, detections)
            result_name = os.path.join(result_path, sub)
            if not cv.imwrite(result_name, pad_img):
                print('Failed to write result image:', result_name)


if __name__ == '__main__':
    main()
读取 TensorFlow 的 .pb 模型文件，打印所有节点的名称和运算类型，以便确定模型的输入、输出张量名：
# Print every node of a frozen TensorFlow graph as "<name> ===> <op>", which
# makes it easy to spot the model's input/output tensors
# (e.g. image_tensor, num_detections, detection_boxes).
import tensorflow as tf

pb_path = r'D:\share\ssd_resnet50_v1_fpn_shared_box_predictor_640x640_coco14_sync_2018_07_03\frozen_inference_graph.pb'
gf = tf.GraphDef()
# A context manager closes the file handle once the graph has been read.
with open(pb_path, 'rb') as f:
    gf.ParseFromString(f.read())
for n in gf.node:
    print(n.name + ' ===> ' + n.op)