opencv调用TensorFlow1.x训练的pb模型


前言

通过opencv读取pb文件来进行object detection是一种相对高效的模型部署方式,调用模型快,检测速率也高。

本文结合个人经验,具体介绍如何使用 OpenCV 调用 TensorFlow 1.x 导出的 pb 模型。

在此之前,你需要已经完成模型训练,并导出 frozen_inference_graph.pb。

OpenCV版本:4.5.5-python-contrib

Tensorflow版本: 1.14.0-gpu


一、搭建虚拟环境

tensorflow1.14+object detection api+python3.7.0虚拟环境搭建

生成 pbtxt 文件时,需要在该 Python 环境中运行转换脚本。

(既然已经生成了pb文件,应该都有该虚拟环境了)


二、通过frozen_pb来生成pbtxt文件

1.获取opencv官方的转化文件

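本人使用的是 SSD 模型,因此需要用 tf_text_graph_common.py 和 tf_text_graph_ssd.py 来转换(Faster R-CNN、Mask R-CNN 等模型则分别对应 tf_text_graph_faster_rcnn.py、tf_text_graph_mask_rcnn.py)。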

OpenCV 官方 GitHub(samples/dnn 目录):https://github.com/opencv/opencv/tree/4.5.5/samples/dnn

或者使用 OpenCV 4.5.5 完整包中的 opencv-4.5.5\samples\dnn 目录(百度网盘,提取码:ev1t;也可以从第三方软件网站下载)。

在这里插入图片描述


无法访问 GitHub 或下载速度较慢的读者,可以直接新建以下两个 py 文件(两者需放在同一目录下):

tf_text_graph_common.py

def tokenize(s):
    """Split protobuf text-format content into a flat list of string tokens.

    * Whitespace, ':', ';' and ',' separate tokens and are dropped.
    * '{', '}', '[' and ']' are emitted as tokens of their own.
    * Text between a quote character and the next *matching* quote becomes one
      token (quotes removed, possibly empty). Inside it, separators, brackets,
      '#' and the other quote character are kept verbatim. Backslash escapes
      are not interpreted.
    * '#' outside a string starts a comment that runs to the end of the line.
    """
    tokens = []
    token = ""
    quote = None  # quote character that opened the current string, if any
    isComment = False
    for symbol in s:
        isComment = (isComment and symbol != '\n') or (quote is None and symbol == '#')
        if isComment:
            continue

        if quote is not None:
            # Inside a string only the matching quote is special; brackets and
            # the other quote kind used to break the value apart.
            if symbol == quote:
                tokens.append(token)  # emitted even when empty: "" is a value
                token = ""
                quote = None
            else:
                token += symbol
        elif symbol in ('"', '\''):
            if token:
                tokens.append(token)
                token = ""
            quote = symbol
        elif symbol in (' ', '\t', '\r', '\n', ':', ';', ','):
            if token:
                tokens.append(token)
                token = ""
        elif symbol in ('{', '}', '[', ']'):
            if token:
                tokens.append(token)
                token = ""
            tokens.append(symbol)
        else:
            token += symbol
    if token:
        tokens.append(token)
    return tokens


def parseMessage(tokens, idx):
    """Parse one ``{ ... }`` message from *tokens*, starting at the '{' at *idx*.

    Every field is collected into a list, so repeated fields append. Nested
    messages become dicts, and ``[a, b]`` lists contribute one entry per
    element under the preceding field name.

    Returns ``(message_dict, index_of_closing_brace)``. Returns ``None`` when
    the tokens run out before a field name could be read.
    """
    assert tokens[idx] == '{'
    fields = {}
    inside_list = False
    while True:
        if not inside_list:
            idx += 1
            if idx >= len(tokens):
                return None
            name = tokens[idx]
            if name == '}':
                return fields, idx
        # While inside [...], the last field name is reused for each element.
        idx += 1
        value = tokens[idx]
        if value == '[':
            inside_list = True
        elif value == ']':
            inside_list = False
        elif value == '{':
            nested, idx = parseMessage(tokens, idx)
            fields.setdefault(name, []).append(nested)
        else:
            fields.setdefault(name, []).append(value)


def readTextMessage(filePath):
    """Read a protobuf text-format file (e.g. pipeline.config or a .pbtxt) into nested dicts.

    Returns ``{}`` when *filePath* is empty/None, or when parsing runs out of
    tokens before the message is complete.
    """
    if not filePath:
        return {}
    with open(filePath, 'rt') as handle:
        content = handle.read()
    # Wrap the whole file in braces so the top level parses like any other message.
    parsed = parseMessage(tokenize('{' + content + '}'), 0)
    if parsed:
        return parsed[0]
    return {}


def listToTensor(values):
    """Wrap a flat Python list as a text-format Const ``tensor`` message.

    All-float lists become DT_FLOAT / ``float_val`` and all-int lists become
    DT_INT32 / ``int_val``. A list mixing ints and floats is promoted to
    DT_FLOAT, with every element converted to float.

    Raises:
        Exception: if any element is a bool, or is not an int/float. A bool
            used to pass as an int and was then printed as ``int_val: true``,
            which is invalid text format.
    """
    # bool is a subclass of int, so it has to be rejected explicitly.
    if any(isinstance(v, bool) for v in values):
        raise Exception('Wrong values types')
    if all(isinstance(v, float) for v in values):
        dtype = 'DT_FLOAT'
        field = 'float_val'
    elif all(isinstance(v, int) for v in values):
        dtype = 'DT_INT32'
        field = 'int_val'
    elif all(isinstance(v, (int, float)) for v in values):
        # Mixed numbers (e.g. [1, 0.5]) are representable exactly as floats.
        dtype = 'DT_FLOAT'
        field = 'float_val'
        values = [float(v) for v in values]
    else:
        raise Exception('Wrong values types')

    msg = {
        'tensor': {
            'dtype': dtype,
            'tensor_shape': {
                'dim': {
                    'size': len(values)
                }
            }
        }
    }
    msg['tensor'][field] = values
    return msg


def addConstNode(name, values, graph_def):
    """Append a Const node called *name* whose 'value' attribute holds *values*."""
    const = NodeDef()
    const.op = 'Const'
    const.name = name
    const.addAttr('value', values)
    graph_def.node.extend([const])


def addSlice(inp, out, begins, sizes, graph_def):
    """Append Slice(*inp*, begins, sizes) named *out* to *graph_def*.

    The begin and size constants are added first, as Const nodes named
    '<out>/begins' and '<out>/sizes'.
    """
    const_names = []
    for suffix, payload in (('/begins', begins), ('/sizes', sizes)):
        const = NodeDef()
        const.name = out + suffix
        const.op = 'Const'
        const.addAttr('value', payload)
        graph_def.node.extend([const])
        const_names.append(const.name)

    sliced = NodeDef()
    sliced.name = out
    sliced.op = 'Slice'
    sliced.input.extend([inp] + const_names)
    graph_def.node.extend([sliced])


def addReshape(inp, out, shape, graph_def):
    """Append Reshape(*inp*, shape) named *out*; the shape is a Const '<out>/shape'."""
    shape_const = NodeDef()
    shape_const.op = 'Const'
    shape_const.name = out + '/shape'
    shape_const.addAttr('value', shape)

    reshaped = NodeDef()
    reshaped.op = 'Reshape'
    reshaped.name = out
    reshaped.input.extend([inp, shape_const.name])

    graph_def.node.extend([shape_const, reshaped])


def addSoftMax(inp, out, graph_def):
    """Append a Softmax node named *out* over the last axis (-1) of *inp*."""
    node = NodeDef()
    node.name, node.op = out, 'Softmax'
    node.input.append(inp)
    node.addAttr('axis', -1)
    graph_def.node.extend([node])


def addFlatten(inp, out, graph_def):
    """Append a Flatten node named *out* that consumes *inp*."""
    node = NodeDef()
    node.name, node.op = out, 'Flatten'
    node.input.append(inp)
    graph_def.node.extend([node])


class NodeDef:
    """Lightweight, dict-backed stand-in for TensorFlow's NodeDef.

    ``attr`` maps attribute names to text-format-ready dicts such as
    ``{'i': 3}``, ``{'b': True}`` or a tensor message built by listToTensor.
    """

    def __init__(self):
        self.Clear()

    def addAttr(self, key, value):
        """Attach attribute *key* built from a Python scalar or list *value*.

        bool -> ``{'b': v}``, int -> ``{'i': v}``, float -> ``{'f': v}``,
        str -> ``{'s': v}``, list -> Const tensor via listToTensor.
        The key must not already be present (checked with ``assert``).

        Raises:
            Exception: for any other value type.
        """
        assert key not in self.attr
        # bool is tested before int because bool is a subclass of int.
        for kind, field in ((bool, 'b'), (int, 'i'), (float, 'f'), (str, 's')):
            if isinstance(value, kind):
                self.attr[key] = {field: value}
                return
        if isinstance(value, list):
            self.attr[key] = listToTensor(value)
            return
        raise Exception('Unknown type of attribute ' + key)

    def Clear(self):
        """Reset the node to an empty state: no inputs, name, op or attributes."""
        self.input = []
        self.name = ""
        self.op = ""
        self.attr = {}


class GraphDef:
    """Minimal stand-in for tf.GraphDef: an ordered list of NodeDef-like nodes
    that can be written out in protobuf text format."""

    def __init__(self):
        self.node = []

    def save(self, filePath):
        """Write every node to *filePath* in protobuf text format.

        Attributes, and keys inside nested attribute dicts, are emitted sorted
        case-insensitively. Scalars are formatted as follows:

        * bools and the strings 'true'/'false' -> bare ``true``/``false``;
        * strings starting with 'DT_' (dtype enums) and strings that parse as a
          float -> written unquoted;
        * any other string -> wrapped in double quotes (no escaping is applied);
        * everything else -> ``str(value)``.
        """
        with open(filePath, 'wt') as f:

            def printAttr(d, indent):
                indent = ' ' * indent
                for key, value in sorted(d.items(), key=lambda x: x[0].lower()):
                    # Repeated fields are stored as lists: one line per item.
                    value = value if isinstance(value, list) else [value]
                    for v in value:
                        if isinstance(v, dict):
                            f.write(indent + key + ' {\n')
                            printAttr(v, len(indent) + 2)
                            f.write(indent + '}\n')
                        else:
                            isString = False
                            if isinstance(v, str) and not v.startswith('DT_'):
                                try:
                                    float(v)
                                except ValueError:
                                    # Only a non-numeric string needs quoting. The
                                    # bare except also swallowed KeyboardInterrupt.
                                    isString = True

                            if isinstance(v, bool):
                                printed = 'true' if v else 'false'
                            elif v == 'true' or v == 'false':
                                printed = 'true' if v == 'true' else 'false'
                            elif isString:
                                printed = '\"%s\"' % v
                            else:
                                printed = str(v)
                            f.write(indent + key + ': ' + printed + '\n')

            for node in self.node:
                f.write('node {\n')
                f.write('  name: \"%s\"\n' % node.name)
                f.write('  op: \"%s\"\n' % node.op)
                for inp in node.input:
                    f.write('  input: \"%s\"\n' % inp)
                for key, value in sorted(node.attr.items(), key=lambda x: x[0].lower()):
                    f.write('  attr {\n')
                    f.write('    key: \"%s\"\n' % key)
                    f.write('    value {\n')
                    printAttr(value, 6)
                    f.write('    }\n')
                    f.write('  }\n')
                f.write('}\n')


def parseTextGraph(filePath):
    """Load a text-format GraphDef (.pbtxt) into the lightweight GraphDef/NodeDef classes.

    Each node keeps its name, op and inputs. Its attributes are stored as
    ``{attr_key: parsed_value_dict}``.
    """
    graph = GraphDef()
    for parsed in readTextMessage(filePath)['node']:
        node = NodeDef()
        node.name = parsed['name'][0]
        node.op = parsed['op'][0]
        node.input = parsed.get('input', [])
        for attr in parsed.get('attr', []):
            node.attr[attr['key'][0]] = attr['value'][0]
        graph.node.append(node)
    return graph


# Removes Identity nodes
def removeIdentity(graph_def):
    """Remove Identity/IdentityN nodes from *graph_def* and rewire their consumers.

    Each removed node is mapped to its first input. Chains of identities that
    appear earlier in node order are followed. Every remaining node's inputs
    are then rewritten through that mapping. ``graph_def.node`` is updated in
    place.
    """
    identities = {}
    kept = []
    # Build a new node list instead of calling list.remove() while iterating:
    # mutating the list under the loop skipped the node right after each
    # removed Identity, so back-to-back identities stayed in the graph.
    for node in graph_def.node:
        if node.op == 'Identity' or node.op == 'IdentityN':
            inp = node.input[0]
            identities[node.name] = identities.get(inp, inp)
        else:
            kept.append(node)
    graph_def.node[:] = kept

    for node in graph_def.node:
        for i, inp in enumerate(node.input):
            if inp in identities:
                node.input[i] = identities[inp]


def removeUnusedNodesAndAttrs(to_remove, graph_def):
    """Delete nodes selected by *to_remove* and strip attributes OpenCV ignores.

    Args:
        to_remove: predicate ``(name, op) -> bool``; nodes for which it returns
            True are deleted from ``graph_def.node``.
        graph_def: graph whose ``node`` list is edited in place.

    Inputs that refer to removed non-Const nodes are dropped from the
    surviving nodes. References to removed Const nodes are kept on purpose.
    """
    unusedAttrs = ['T', 'Tshape', 'N', 'Tidx', 'Tdim', 'use_cudnn_on_gpu',
                   'Index', 'Tperm', 'is_training', 'Tpaddings']

    # A set keeps the input-cleanup pass linear instead of O(nodes * removed).
    removedNodes = set()

    # Walk backwards so deleting by index does not shift unvisited nodes.
    for i in reversed(range(len(graph_def.node))):
        node = graph_def.node[i]

        if to_remove(node.name, node.op):
            if node.op != 'Const':
                removedNodes.add(node.name)

            del graph_def.node[i]
        else:
            for attr in unusedAttrs:
                node.attr.pop(attr, None)

    # Remove references to removed nodes except Const nodes.
    for node in graph_def.node:
        node.input[:] = [inp for inp in node.input if inp not in removedNodes]


def writeTextGraph(modelPath, outputPath, outNodes):
    """Dump the frozen graph at *modelPath* to a text-format file at *outputPath*.

    OpenCV's ``cv.dnn.writeTextGraph`` is tried first. If that raises (for
    example, cv2 cannot be imported or the call fails), it falls back to
    TensorFlow 1.x, which:

    * keeps only the subgraph between 'image_tensor' and *outNodes*;
    * sorts it by execution order;
    * clears the binary ``tensor_content`` of Const nodes;
    * writes the result as text.
    """
    try:
        import cv2 as cv

        cv.dnn.writeTextGraph(modelPath, outputPath)
    except Exception:
        # Catch ordinary errors only. The former bare ``except:`` also turned
        # KeyboardInterrupt/SystemExit into a silent TensorFlow fallback.
        import tensorflow as tf
        from tensorflow.tools.graph_transforms import TransformGraph

        with tf.gfile.FastGFile(modelPath, 'rb') as f:
            graph_def = tf.compat.v1.GraphDef()
            graph_def.ParseFromString(f.read())

            graph_def = TransformGraph(graph_def, ['image_tensor'], outNodes, ['sort_by_execution_order'])

            # Weights are not needed in the .pbtxt (OpenCV reads them from the
            # .pb), so drop them to keep the text file small.
            for node in graph_def.node:
                if node.op == 'Const':
                    if 'value' in node.attr and node.attr['value'].tensor.tensor_content:
                        node.attr['value'].tensor.tensor_content = b''

        tf.io.write_graph(graph_def, "", outputPath, as_text=True)


tf_text_graph_ssd.py

# This file is a part of OpenCV project.
# It is a subject to the license terms in the LICENSE file found in the top-level directory
# of this distribution and at http://opencv.org/license.html.
#
# Copyright (C) 2018, Intel Corporation, all rights reserved.
# Third party copyrights are property of their respective owners.
#
# Use this script to get the text graph representation (.pbtxt) of SSD-based
# deep learning network trained in TensorFlow Object Detection API.
# Then you can import it with a binary frozen graph (.pb) using readNetFromTensorflow() function.
# See details and examples on the following wiki page: https://github.com/opencv/opencv/wiki/TensorFlow-Object-Detection-API
import argparse
import re
from math import sqrt
from tf_text_graph_common import *

class SSDAnchorGenerator:
    """Per-layer prior box sizes (in pixels) for an ``ssd_anchor_generator`` config.

    Scales are spaced linearly from *min_scale* to *max_scale* over
    *num_layers* layers, followed by an extra 1.0 entry. That entry supplies
    the "next scale" used by the interpolated box of the last layer.
    """

    def __init__(self, min_scale, max_scale, num_layers, aspect_ratios,
                 reduce_boxes_in_lowest_layer, image_width, image_height):
        self.min_scale = min_scale
        self.aspect_ratios = aspect_ratios
        self.reduce_boxes_in_lowest_layer = reduce_boxes_in_lowest_layer
        self.image_width = image_width
        self.image_height = image_height
        step = num_layers - 1
        self.scales = [min_scale + (max_scale - min_scale) * i / step
                       for i in range(num_layers)]
        self.scales.append(1.0)

    def get(self, layer_id):
        """Return ``(widths, heights)`` of the priors for layer *layer_id*, in pixels.

        Relative sizes are multiplied by the shorter image side.
        """
        if layer_id == 0 and self.reduce_boxes_in_lowest_layer:
            base = self.min_scale
            pairs = [(0.1, 0.1),
                     (base * sqrt(2.0), base / sqrt(2.0)),
                     (base * sqrt(0.5), base / sqrt(0.5))]
        else:
            scale = self.scales[layer_id]
            pairs = [(scale * sqrt(ar), scale / sqrt(ar)) for ar in self.aspect_ratios]
            # Extra square box between this layer's scale and the next one.
            interpolated = sqrt(scale * self.scales[layer_id + 1])
            pairs.append((interpolated, interpolated))
        side = min(self.image_width, self.image_height)
        return [w * side for w, _ in pairs], [h * side for _, h in pairs]


class MultiscaleAnchorGenerator:
    """Per-layer anchor sizes for a ``multiscale_anchor_generator`` config.

    Layer *k* has base size ``2**(min_level + k) * anchor_scale``. Each
    (aspect ratio, octave scale) pair yields one anchor.
    """

    def __init__(self, min_level, aspect_ratios, scales_per_octave, anchor_scale):
        self.min_level = min_level
        self.aspect_ratios = aspect_ratios
        self.anchor_scale = anchor_scale
        self.scales = [2**(float(octave) / scales_per_octave)
                       for octave in range(scales_per_octave)]

    def get(self, layer_id):
        """Return ``(widths, heights)`` for layer *layer_id*, ordered by aspect ratio then scale."""
        widths = []
        heights = []
        for aspect in self.aspect_ratios:
            for octave_scale in self.scales:
                size = 2**(self.min_level + layer_id) * self.anchor_scale * octave_scale
                ratio = sqrt(aspect)
                widths.append(size * ratio)
                heights.append(size / ratio)
        return widths, heights


def createSSDGraph(modelPath, configPath, outputPath):
    """Build an OpenCV-loadable text graph (.pbtxt) for a TF Object Detection API SSD model.

    Steps:

    1. Read the training pipeline config at *configPath* to get the number of
       classes, the fixed input size, the anchor generator, the box-coder
       variances and the NMS settings.
    2. Dump the frozen graph at *modelPath* to *outputPath* as text and parse it.
    3. Fuse unfused batch-norm and nearest-neighbour-resize subgraphs, then
       drop the nodes OpenCV does not need.
    4. Append an OpenCV SSD head (Flatten/ConcatV2, PriorBox, Sigmoid,
       DetectionOutput).
    5. Overwrite *outputPath* with the result.

    Exits the process with status 1 if the config uses an anchor generator
    other than ssd_anchor_generator or multiscale_anchor_generator.
    """
    # Nodes that should be kept.
    keepOps = ['Conv2D', 'BiasAdd', 'Add', 'AddV2', 'Relu', 'Relu6', 'Placeholder', 'FusedBatchNorm',
               'DepthwiseConv2dNative', 'ConcatV2', 'Mul', 'MaxPool', 'AvgPool', 'Identity',
               'Sub', 'ResizeNearestNeighbor', 'Pad', 'FusedBatchNormV3', 'Mean']

    # Node with which prefixes should be removed
    prefixesToRemove = ('MultipleGridAnchorGenerator/', 'Concatenate/', 'Postprocessor/', 'Preprocessor/map')

    # Load a config file.
    config = readTextMessage(configPath)
    config = config['model'][0]['ssd'][0]
    num_classes = int(config['num_classes'][0])

    # The input size is taken from fixed_shape_resizer; other resizers raise KeyError here.
    fixed_shape_resizer = config['image_resizer'][0]['fixed_shape_resizer'][0]
    image_width = int(fixed_shape_resizer['width'][0])
    image_height = int(fixed_shape_resizer['height'][0])

    box_predictor = 'convolutional' if 'convolutional_box_predictor' in config['box_predictor'][0] else 'weight_shared_convolutional'

    anchor_generator = config['anchor_generator'][0]
    if 'ssd_anchor_generator' in anchor_generator:
        ssd_anchor_generator = anchor_generator['ssd_anchor_generator'][0]
        min_scale = float(ssd_anchor_generator['min_scale'][0])
        max_scale = float(ssd_anchor_generator['max_scale'][0])
        num_layers = int(ssd_anchor_generator['num_layers'][0])
        aspect_ratios = [float(ar) for ar in ssd_anchor_generator['aspect_ratios']]
        reduce_boxes_in_lowest_layer = True
        if 'reduce_boxes_in_lowest_layer' in ssd_anchor_generator:
            reduce_boxes_in_lowest_layer = ssd_anchor_generator['reduce_boxes_in_lowest_layer'][0] == 'true'
        priors_generator = SSDAnchorGenerator(min_scale, max_scale, num_layers,
                                              aspect_ratios, reduce_boxes_in_lowest_layer,
                                              image_width, image_height)


        print('Scale: [%f-%f]' % (min_scale, max_scale))
        print('Aspect ratios: %s' % str(aspect_ratios))
        print('Reduce boxes in the lowest layer: %s' % str(reduce_boxes_in_lowest_layer))
    elif 'multiscale_anchor_generator' in anchor_generator:
        multiscale_anchor_generator = anchor_generator['multiscale_anchor_generator'][0]
        min_level = int(multiscale_anchor_generator['min_level'][0])
        max_level = int(multiscale_anchor_generator['max_level'][0])
        anchor_scale = float(multiscale_anchor_generator['anchor_scale'][0])
        aspect_ratios = [float(ar) for ar in multiscale_anchor_generator['aspect_ratios']]
        scales_per_octave = int(multiscale_anchor_generator['scales_per_octave'][0])
        num_layers = max_level - min_level + 1
        priors_generator = MultiscaleAnchorGenerator(min_level, aspect_ratios,
                                                     scales_per_octave, anchor_scale)
        print('Levels: [%d-%d]' % (min_level, max_level))
        print('Anchor scale: %f' % anchor_scale)
        print('Scales per octave: %d' % scales_per_octave)
        print('Aspect ratios: %s' % str(aspect_ratios))
    else:
        print('Unknown anchor_generator')
        # Exit with a non-zero status so calling scripts can detect the failure
        # (exit(0) used to report success).
        import sys
        sys.exit(1)

    print('Number of classes: %d' % num_classes)
    print('Number of layers: %d' % num_layers)
    print('box predictor: %s' % box_predictor)
    print('Input image size: %dx%d' % (image_width, image_height))

    # Read the graph.
    outNames = ['num_detections', 'detection_scores', 'detection_boxes', 'detection_classes']

    writeTextGraph(modelPath, outputPath, outNames)
    graph_def = parseTextGraph(outputPath)

    def getUnconnectedNodes():
        # Names of nodes that no later node in graph_def.node uses as an input.
        unconnected = []
        for node in graph_def.node:
            unconnected.append(node.name)
            for inp in node.input:
                if inp in unconnected:
                    unconnected.remove(inp)
        return unconnected


    def fuse_nodes(nodesToKeep):
        # Detect unfused batch normalization nodes and fuse them.
        # Add_0 <-- moving_variance, add_y
        # Rsqrt <-- Add_0
        # Mul_0 <-- Rsqrt, gamma
        # Mul_1 <-- input, Mul_0
        # Mul_2 <-- moving_mean, Mul_0
        # Sub_0 <-- beta, Mul_2
        # Add_1 <-- Mul_1, Sub_0
        nodesMap = {node.name: node for node in graph_def.node}
        subgraphBatchNorm = ['Add',
            ['Mul', 'input', ['Mul', ['Rsqrt', ['Add', 'moving_variance', 'add_y']], 'gamma']],
            ['Sub', 'beta', ['Mul', 'moving_mean', 'Mul_0']]]
        subgraphBatchNormV2 = ['AddV2',
            ['Mul', 'input', ['Mul', ['Rsqrt', ['AddV2', 'moving_variance', 'add_y']], 'gamma']],
            ['Sub', 'beta', ['Mul', 'moving_mean', 'Mul_0']]]
        # Detect unfused nearest neighbor resize.
        subgraphResizeNN = ['Reshape',
            ['Mul', ['Reshape', 'input', ['Pack', 'shape_1', 'shape_2', 'shape_3', 'shape_4', 'shape_5']],
                    'ones'],
            ['Pack', ['StridedSlice', ['Shape', 'input'], 'stack', 'stack_1', 'stack_2'],
                     'out_height', 'out_width', 'out_channels']]
        def checkSubgraph(node, targetNode, inputs, fusedNodes):
            # Recursively match node against the pattern [op, input_pattern...].
            # String leaves bind the matched input name into `inputs`; every
            # matched node is recorded in `fusedNodes`, the root first.
            op = targetNode[0]
            if node.op == op and (len(node.input) >= len(targetNode) - 1):
                fusedNodes.append(node)
                for i, inpOp in enumerate(targetNode[1:]):
                    if isinstance(inpOp, list):
                        if not node.input[i] in nodesMap or \
                           not checkSubgraph(nodesMap[node.input[i]], inpOp, inputs, fusedNodes):
                            return False
                    else:
                        inputs[inpOp] = node.input[i]

                return True
            else:
                return False

        nodesToRemove = []
        for node in graph_def.node:
            inputs = {}
            fusedNodes = []
            if checkSubgraph(node, subgraphBatchNorm, inputs, fusedNodes) or \
               checkSubgraph(node, subgraphBatchNormV2, inputs, fusedNodes):
                name = node.name
                node.Clear()
                node.name = name
                node.op = 'FusedBatchNorm'
                node.input.append(inputs['input'])
                node.input.append(inputs['gamma'])
                node.input.append(inputs['beta'])
                node.input.append(inputs['moving_mean'])
                node.input.append(inputs['moving_variance'])
                node.addAttr('epsilon', 0.001)
                nodesToRemove += fusedNodes[1:]

            inputs = {}
            fusedNodes = []
            if checkSubgraph(node, subgraphResizeNN, inputs, fusedNodes):
                name = node.name
                node.Clear()
                node.name = name
                node.op = 'ResizeNearestNeighbor'
                node.input.append(inputs['input'])
                node.input.append(name + '/output_shape')

                out_height_node = nodesMap[inputs['out_height']]
                out_width_node = nodesMap[inputs['out_width']]
                out_height = int(out_height_node.attr['value']['tensor'][0]['int_val'][0])
                out_width = int(out_width_node.attr['value']['tensor'][0]['int_val'][0])

                shapeNode = NodeDef()
                shapeNode.name = name + '/output_shape'
                shapeNode.op = 'Const'
                shapeNode.addAttr('value', [out_height, out_width])
                graph_def.node.insert(graph_def.node.index(node), shapeNode)
                nodesToKeep.append(shapeNode.name)

                nodesToRemove += fusedNodes[1:]
        for node in nodesToRemove:
            graph_def.node.remove(node)

    nodesToKeep = []
    fuse_nodes(nodesToKeep)

    removeIdentity(graph_def)

    def to_remove(name, op):
        return (not name in nodesToKeep) and \
               (op == 'Const' or (not op in keepOps) or name.startswith(prefixesToRemove))

    removeUnusedNodesAndAttrs(to_remove, graph_def)


    # Connect input node to the first layer
    assert(graph_def.node[0].op == 'Placeholder')
    try:
        input_shape = graph_def.node[0].attr['shape']['shape'][0]['dim']
        input_shape[1]['size'] = image_height
        input_shape[2]['size'] = image_width
    except (KeyError, IndexError, TypeError):
        # The placeholder carries no (complete) static shape; leave it as is.
        print("Input shapes are undefined")
    # assert(graph_def.node[1].op == 'Conv2D')
    # Keep only the weights input of the first layer and feed it the placeholder directly.
    weights = graph_def.node[1].input[-1]
    for i in range(len(graph_def.node[1].input)):
        graph_def.node[1].input.pop()
    graph_def.node[1].input.append(graph_def.node[0].name)
    graph_def.node[1].input.append(weights)

    # check and correct the case when preprocessing block is after input
    preproc_id = "Preprocessor/"
    if graph_def.node[2].name.startswith(preproc_id) and \
        graph_def.node[2].input[0].startswith(preproc_id):

        if not any(preproc_id in inp for inp in graph_def.node[3].input):
            graph_def.node[3].input.insert(0, graph_def.node[2].name)


    # Create SSD postprocessing head ###############################################

    # Concatenate predictions of classes, predictions of bounding boxes and proposals.
    def addConcatNode(name, inputs, axisNodeName):
        concat = NodeDef()
        concat.name = name
        concat.op = 'ConcatV2'
        for inp in inputs:
            concat.input.append(inp)
        concat.input.append(axisNodeName)
        graph_def.node.extend([concat])

    addConstNode('concat/axis_flatten', [-1], graph_def)
    addConstNode('PriorBox/concat/axis', [-2], graph_def)

    for label in ['ClassPredictor', 'BoxEncodingPredictor' if box_predictor == 'convolutional' else 'BoxPredictor']:
        concatInputs = []
        for i in range(num_layers):
            # Flatten predictions
            flatten = NodeDef()
            if box_predictor == 'convolutional':
                inpName = 'BoxPredictor_%d/%s/BiasAdd' % (i, label)
            else:
                if i == 0:
                    inpName = 'WeightSharedConvolutionalBoxPredictor/%s/BiasAdd' % label
                else:
                    inpName = 'WeightSharedConvolutionalBoxPredictor_%d/%s/BiasAdd' % (i, label)
            flatten.input.append(inpName)
            flatten.name = inpName + '/Flatten'
            flatten.op = 'Flatten'

            concatInputs.append(flatten.name)
            graph_def.node.extend([flatten])
        addConcatNode('%s/concat' % label, concatInputs, 'concat/axis_flatten')

    # Mark box-regression convolutions; exactly one is expected per layer.
    # Raw strings: '\d' in a plain literal is an invalid escape sequence.
    num_matched_layers = 0
    for node in graph_def.node:
        if re.match(r'BoxPredictor_\d/BoxEncodingPredictor/convolution', node.name) or \
           re.match(r'BoxPredictor_\d/BoxEncodingPredictor/Conv2D', node.name) or \
           re.match(r'WeightSharedConvolutionalBoxPredictor(_\d)*/BoxPredictor/Conv2D', node.name):
            node.addAttr('loc_pred_transposed', True)
            num_matched_layers += 1
    assert(num_matched_layers == num_layers)

    # Add layers that generate anchors (bounding boxes proposals).
    priorBoxes = []
    boxCoder = config['box_coder'][0]
    fasterRcnnBoxCoder = boxCoder['faster_rcnn_box_coder'][0]
    boxCoderVariance = [1.0/float(fasterRcnnBoxCoder['x_scale'][0]), 1.0/float(fasterRcnnBoxCoder['y_scale'][0]), 1.0/float(fasterRcnnBoxCoder['width_scale'][0]), 1.0/float(fasterRcnnBoxCoder['height_scale'][0])]
    for i in range(num_layers):
        priorBox = NodeDef()
        priorBox.name = 'PriorBox_%d' % i
        priorBox.op = 'PriorBox'
        if box_predictor == 'convolutional':
            priorBox.input.append('BoxPredictor_%d/BoxEncodingPredictor/BiasAdd' % i)
        else:
            if i == 0:
                priorBox.input.append('WeightSharedConvolutionalBoxPredictor/BoxPredictor/Conv2D')
            else:
                priorBox.input.append('WeightSharedConvolutionalBoxPredictor_%d/BoxPredictor/BiasAdd' % i)
        priorBox.input.append(graph_def.node[0].name)  # image_tensor

        priorBox.addAttr('flip', False)
        priorBox.addAttr('clip', False)

        widths, heights = priors_generator.get(i)

        priorBox.addAttr('width', widths)
        priorBox.addAttr('height', heights)
        priorBox.addAttr('variance', boxCoderVariance)

        graph_def.node.extend([priorBox])
        priorBoxes.append(priorBox.name)

    # Compare this layer's output with Postprocessor/Reshape
    addConcatNode('PriorBox/concat', priorBoxes, 'concat/axis_flatten')

    # Sigmoid for classes predictions and DetectionOutput layer
    addReshape('ClassPredictor/concat', 'ClassPredictor/concat3d', [0, -1, num_classes + 1], graph_def)

    sigmoid = NodeDef()
    sigmoid.name = 'ClassPredictor/concat/sigmoid'
    sigmoid.op = 'Sigmoid'
    sigmoid.input.append('ClassPredictor/concat3d')
    graph_def.node.extend([sigmoid])

    addFlatten(sigmoid.name, sigmoid.name + '/Flatten', graph_def)

    detectionOut = NodeDef()
    detectionOut.name = 'detection_out'
    detectionOut.op = 'DetectionOutput'

    if box_predictor == 'convolutional':
        detectionOut.input.append('BoxEncodingPredictor/concat')
    else:
        detectionOut.input.append('BoxPredictor/concat')
    detectionOut.input.append(sigmoid.name + '/Flatten')
    detectionOut.input.append('PriorBox/concat')

    # num_classes + 1 because label 0 is the background class.
    detectionOut.addAttr('num_classes', num_classes + 1)
    detectionOut.addAttr('share_location', True)
    detectionOut.addAttr('background_label_id', 0)

    # NMS parameters come from the config; the hard-coded values are fallbacks.
    postProcessing = config['post_processing'][0]
    batchNMS = postProcessing['batch_non_max_suppression'][0]

    if 'iou_threshold' in batchNMS:
        detectionOut.addAttr('nms_threshold', float(batchNMS['iou_threshold'][0]))
    else:
        detectionOut.addAttr('nms_threshold', 0.6)

    if 'score_threshold' in batchNMS:
        detectionOut.addAttr('confidence_threshold', float(batchNMS['score_threshold'][0]))
    else:
        detectionOut.addAttr('confidence_threshold', 0.01)

    if 'max_detections_per_class' in batchNMS:
        detectionOut.addAttr('top_k', int(batchNMS['max_detections_per_class'][0]))
    else:
        detectionOut.addAttr('top_k', 100)

    if 'max_total_detections' in batchNMS:
        detectionOut.addAttr('keep_top_k', int(batchNMS['max_total_detections'][0]))
    else:
        detectionOut.addAttr('keep_top_k', 100)

    detectionOut.addAttr('code_type', "CENTER_SIZE")

    graph_def.node.extend([detectionOut])

    # Repeatedly prune dangling nodes until only detection_out is left unconsumed.
    while True:
        unconnectedNodes = getUnconnectedNodes()
        unconnectedNodes.remove(detectionOut.name)
        if not unconnectedNodes:
            break

        for name in unconnectedNodes:
            for i in range(len(graph_def.node)):
                if graph_def.node[i].name == name:
                    del graph_def.node[i]
                    break

    # Save as text.
    graph_def.save(outputPath)


if __name__ == "__main__":
    # Command-line entry point. All three paths are required; the result is
    # written to --output and can then be loaded together with the .pb file.
    cli = argparse.ArgumentParser(
        description='Run this script to get a text graph of '
                    'SSD model from TensorFlow Object Detection API. '
                    'Then pass it with .pb file to cv::dnn::readNetFromTensorflow function.')
    for flag, text in (('--input', 'Path to frozen TensorFlow graph.'),
                       ('--output', 'Path to output text graph.'),
                       ('--config', 'Path to a *.config file is used for training.')):
        cli.add_argument(flag, required=True, help=text)
    opts = cli.parse_args()

    createSSDGraph(opts.input, opts.config, opts.output)


2.生成pbtxt文件

若从百度网盘下载了 OpenCV 完整包,其中会有以下目录:

在这里插入图片描述


然后打开终端(Terminal),并确保当前环境中的 TensorFlow 为 1.x 版本:

cd 到 tf_text_graph_ssd.py 所在的目录(tf_text_graph_common.py 也必须放在同一目录下,因为前者通过 from tf_text_graph_common import * 导入后者)

输入以下命令(建议使用绝对路径;请将 --input、--config、--output 分别替换为你自己的 frozen_inference_graph.pb、pipeline.config 以及要输出的 .pbtxt 文件路径):

python tf_text_graph_ssd.py --input E:\opencv-4.5.5\samples\dnn\test\frozen_inference_graph.pb  --output E:\opencv-4.5.5\samples\dnn\test\frozen_inference_graph.pbtxt  --config E:\opencv-4.5.5\samples\dnn\test\pipeline.config

在这里插入图片描述


生成成功后,终端只会打印模型参数(类别数、层数、输入尺寸等),如下所示,不会有其他提示;生成的 pbtxt 文件一般不大(通常在 1 MB 以下):

在这里插入图片描述

如此一来便有以下三个文件:
在这里插入图片描述

三、OpenCV调用pb模型

在上述目录下创建opencv_test.py:

import cv2
import time


# Edit classNames to match your label map file (labels_items.pbtxt):
# each id must map to the label that has the same id there.
# NOTE: the converted graph uses background_label_id 0, so id 0 is presumably
# never reported by the detector.
classNames = {
    0: 'AAA',
    1: 'S07_B_normal',
    2: 'S07_B_1',
    3: 'S07_B_2',
    4: 'S07_B_3',
    5: 'S07_B_4',
    6: 'S07_B_5',
    7: 'XXX',
    8: 'YYY',
    9: 'ZZZ',
}

# Change these to your own paths.
pb_path = r'E:\opencv-4.5.5\samples\dnn\test\frozen_inference_graph.pb'
pbtxt_path = r'E:\opencv-4.5.5\samples\dnn\test\frozen_inference_graph.pbtxt'
img_path = r'E:\opencv-4.5.5\samples\dnn\test\001_1.jpg'


def id_class_name(class_id, classes):
    """Return the label in *classes* whose key equals *class_id*, or None if absent.

    Keys are compared with ``==``, so a float id such as 1.0 (the network
    returns ids as floats) still matches the int key 1.
    """
    matches = (label for key, label in classes.items() if class_id == key)
    return next(matches, None)


# Measure total elapsed time (model loading + inference + drawing).
t1 = time.time()

model = cv2.dnn.readNetFromTensorflow(pb_path, pbtxt_path)
image = cv2.imread(img_path)
# cv2.imread returns None instead of raising when the file cannot be read;
# without this check the failure surfaces later as an AttributeError on .shape.
if image is None:
    raise FileNotFoundError('Cannot read image: %s' % img_path)

image_height, image_width, _ = image.shape

# The blob size should match fixed_shape_resizer in pipeline.config; 300x300 is
# hard-coded here. swapRB swaps the first and last channels (BGR -> RGB).
model.setInput(cv2.dnn.blobFromImage(image, size=(300, 300), swapRB=True))
output = model.forward()

# Each DetectionOutput row is
# [batch_id, class_id, confidence, left, top, right, bottom];
# box coordinates are relative to the image size.
for detection in output[0, 0, :, :]:
    confidence = detection[2]
    if confidence > .5:
        class_id = detection[1]
        class_name = id_class_name(class_id, classNames)
        if class_name is None:
            # Id missing from classNames: use a placeholder rather than
            # crashing on string concatenation with None.
            class_name = 'id_%d' % int(class_id)
        print(str(class_id) + " " + str(confidence) + " " + class_name)

        # Box corners in pixels. Indices 5 and 6 are the right/bottom edges,
        # not a width/height.
        box_left = detection[3] * image_width
        box_top = detection[4] * image_height
        box_right = detection[5] * image_width
        box_bottom = detection[6] * image_height
        print(box_left, box_top, box_right, box_bottom)
        cv2.rectangle(image, (int(box_left), int(box_top)), (int(box_right), int(box_bottom)), (0, 255, 0), thickness=2)
        cv2.putText(image, class_name, (int(box_left), int(box_top + .05 * image_height)), cv2.FONT_HERSHEY_SIMPLEX,
                    (.002 * image_width), (0, 0, 255))

cv2.imshow('image', image)
# cv2.imwrite("image_box_text.jpg", image)

t2 = time.time()
print('time:', t2 - t1)

cv2.waitKey(0)
cv2.destroyAllWindows()

输出结果:
在这里插入图片描述

  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值