tf-faster-rcnn模型训练vgg16的ckpy转pb

摘要

代码地址:https://github.com/endernewton/tf-faster-rcnn
pb输出节点issues:https://github.com/endernewton/tf-faster-rcnn/issues/340
vgg节点博客:https://www.cnblogs.com/zerotoinfinity/p/10242849.html
trainval_net修改在pycharm运行:http://www.xcx1024.com/ArtInfo/2613074.html

ckpt模型寻找节点

import tensorflow as tf
import os


path = r'G:\cai_op\tf-faster-rcnn\output\vgg16\voc_2007_trainval\default'
with tf.Session() as sess:
    # 加载模型定义的graph
    saver = tf.train.import_meta_graph(r'G:\cai_op\tf-faster-rcnn\output\vgg16\voc_2007_trainval\default\vgg16_faster_rcnn_iter_5000.ckpt.meta')
    # 方式一:加载指定文件夹下最近保存的一个模型的数据
    # saver.restore(sess, tf.train.latest_checkpoint('./'))
    # 方式二:指定具体某个数据,需要注意的是,指定的文件不要包含后缀
    saver.restore(sess, os.path.join(path, 'vgg16_faster_rcnn_iter_5000.ckpt'))

    # 查看模型中的trainable variables
    tvs = [v for v in tf.trainable_variables()]
    for v in tvs:
        print(v.name)
        print(sess.run(v))

    # 查看模型中的所有tensor或者operations
    gv = [v for v in tf.global_variables()]
    for v in gv:
        print(v.name)

    # 获得几乎所有的operations相关的tensor
    ops = [o for o in sess.graph.get_operations()]
    for o in ops:
        print(o.name)

ckpt转pb模型1

"""
Demo script showing detections in sample images.

See README.md for installation instructions before running.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from lib.model.config import cfg
from lib.model.test import im_detect
from lib.model.nms_wrapper import nms

from lib.utils.timer import Timer
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import os, cv2
import argparse
from tensorflow.python.tools import freeze_graph
from tensorflow.python.framework import graph_util

from lib.nets.vgg16 import vgg16
from lib.nets.resnet_v1 import resnetv1

CLASSES = ('__background__',
           'fire','smoke')

NETS = {
   'vgg16': ('vgg16_faster_rcnn_iter_110000.ckpt',),'res101': ('res101_faster_rcnn_iter_110000.ckpt',)}
DATASETS= {
   'pascal_voc': ('voc_2007_trainval',),'pascal_voc_0712': ('voc_2007_trainval+voc_2012_trainval',)}

def vis_detections(im, class_name, dets, thresh=0.5):
    """Draw detected bounding boxes."""
    inds = np.where(dets[:, -1] >= thresh)[0]
    if len(inds) == 0:
        return

    im = im[:, :, (2, 1, 0)]
    fig, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(im, aspect='equal')
    for i in inds:
        bbox = dets[i, :4]
        score = dets[i, -1]

        ax.add_patch(
            plt.Rectangle((bbox[0], bbox[1]),
                          bbox[2] - bbox[0],
                          bbox[3] - bbox[1], fill=False,
                          edgecolor='red', linewidth=3.5)
            )
        ax.text(bbox[0], bbox[1] - 2,
                '{:s} {:.3f}'.format(class_name, score),
                bbox=dict(facecolor='blue', alpha=0.5),
                fontsize=14, color='white')

    ax.set_title(('{} detections with '
                  'p({} | box) >= {:.1f}').format(class_name, class_name,
                                                  thresh),
                  fontsize=14)
    plt.axis('off')
    plt.tight_layout()
    plt.draw()

def demo(sess, net, image_name):
    """Detect object classes in an image using pre-computed object proposals."""

    # Load the demo image
    im_file = os.path.join(cfg.DATA_DIR, 'demo', image_name)
    im = cv2.imread(im_file)

    # Detect all object classes and regress object bounds
    timer = Timer()
    timer.tic()
    scores, boxes = im_detect(sess, net, im)
    timer.toc(
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值