tensorflow 使用别人的模型

其实想要使用别人训练好的模型很简单,确定模型输入输出张量名,跑一下就可以:

import numpy as np
import tensorflow as tf
import cv2 as cv
import os

def main():
    folder_path = r'D:\share\samples'
    result_path = r'D:\share\test_result'
    if not os.path.exists(result_path):
        os.mkdir(result_path)

    usedlabel = [1, 3, 6, 8, 10, 13]
    vehiclelabel = [3, 6, 8]
    font = cv.FONT_HERSHEY_SIMPLEX

    class_name = ["person", "bicycle", "car", "motorbike", "aeroplane", "bus", "train", "truck"]
    model_path = r'D:\share\ssd_resnet50_v1_fpn_shared_box_predictor_640x640_coco14_sync_2018_07_03\frozen_inference_graph.pb'

    # Read the graph.
    with tf.gfile.FastGFile(
            model_path, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    with tf.Session() as sess:
        # Restore session
        sess.graph.as_default()
        tf.import_graph_def(graph_def, name='')

        for sub in os.listdir(folder_path):
            if not sub.endswith('.jpg'):
                continue
            img_name = os.path.join(folder_path, sub)

            result_name = os.path.join(result_path, sub)
            img = cv.imread(img_name)
            pad_img = pad_to_square(img, [640, 640])
            change_img = pad_img[:, :, [2, 1, 0]]  # BGR2RGB
            #cv.namedWindow("pad_img")
            # Run the model
            out = sess.run([sess.graph.get_tensor_by_name('num_detections:0'),
                            sess.graph.get_tensor_by_name('detection_scores:0'),
                            sess.graph.get_tensor_by_name('detection_boxes:0'),
                            sess.graph.get_tensor_by_name('detection_classes:0')],
                           feed_dict={'image_tensor:0': pad_img.reshape(1, change_img.shape[0], change_img.shape[1], 3)})

            # Visualize detected bounding boxes.
            num_detections = int(out[0][0])
            classlist = []
            bboxlist = []

            for i in range(num_detections):
                classId = int(out[3][0][i])
                score = float(out[1][0][i])
                bbox = [float(v) for v in out[2][0][i]]
                if score < 0.5:  # 得分小于此不标
                    continue
                x = bbox[1] * pad_img.shape[0]
                y = bbox[0] * pad_img.shape[1]
                right = bbox[3] * pad_img.shape[0]
                bottom = bbox[2] * pad_img.shape[1]

                # if (classId in vehiclelabel) and (right - x < 40 or bottom - y < 40):
                #     continue

                classlist.append(classId)
                bboxlist.append([x, y, right, bottom])

            assert len(classlist) == len(bboxlist)

            for i, box in enumerate(bboxlist):
                p1 = (int(box[0]), int(box[1]))
                p2 = (int(box[2]), int(box[3]))
                if classlist[i] == 1:
                    cv.rectangle(pad_img, p1, p2, (255, 255, 0), thickness=2)
                    cv.putText(pad_img, class_name[classlist[i]-1], p1, font, 0.8, (255, 255, 0), 2, False)
                elif classlist[i] == 3:
                    cv.rectangle(pad_img, p1, p2, (0, 0, 255), thickness=2)
                    cv.putText(pad_img, class_name[classlist[i] - 1], p1, font, 0.8, (0, 0, 255), 2, False)
                elif classlist[i] == 6:
                    cv.rectangle(pad_img, p1, p2, (0, 255, 255), thickness=2)
                    cv.putText(pad_img, class_name[classlist[i] - 1], p1, font, 0.8, (0, 255, 255), 2, False)
                elif classlist[i] == 8:
                    cv.rectangle(pad_img, p1, p2, (255, 0, 255), thickness=2)
                    cv.putText(pad_img, class_name[classlist[i] - 1], p1, font, 0.8, (255, 0, 255), 2, False)
                else:
                    pass
            cv.imwrite(result_name, pad_img)

if __name__ == '__main__':
    main()

 

读取tensorflow.pb,输出节点名,以便确定输入输出:

import tensorflow as tf

gf = tf.GraphDef()
gf.ParseFromString(open(r'D:\share\ssd_resnet50_v1_fpn_shared_box_predictor_640x640_coco14_sync_2018_07_03\frozen_inference_graph.pb', 'rb').read())

for n in gf.node:
    print(n.name + ' ===> ' + n.op)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值