摘要
代码地址:https://github.com/endernewton/tf-faster-rcnn
pb输出节点issues:https://github.com/endernewton/tf-faster-rcnn/issues/340
vgg节点博客:https://www.cnblogs.com/zerotoinfinity/p/10242849.html
修改trainval_net以便在PyCharm中运行:http://www.xcx1024.com/ArtInfo/2613074.html
在ckpt模型中查找节点名称
import tensorflow as tf
import os
# Inspect a trained checkpoint: restore it into a fresh session and print the
# names of its variables and graph operations (useful for locating the output
# node names needed when freezing the graph).

# Directory that holds the checkpoint files.
path = r'G:\cai_op\tf-faster-rcnn\output\vgg16\voc_2007_trainval\default'
# Checkpoint prefix. It is handed to restore() as-is (no file suffix), and the
# graph definition is read from "<prefix>.meta" in the same directory.
ckpt_prefix = os.path.join(path, 'vgg16_faster_rcnn_iter_5000.ckpt')

with tf.Session() as sess:
    # Load the graph definition saved alongside the checkpoint.
    saver = tf.train.import_meta_graph(ckpt_prefix + '.meta')
    # Option 1: restore the most recently saved checkpoint in a directory.
    # saver.restore(sess, tf.train.latest_checkpoint('./'))
    # Option 2: restore one specific checkpoint. Note that the file name
    # given here must not include a suffix.
    saver.restore(sess, ckpt_prefix)

    # Print each trainable variable's name followed by its restored value.
    # sess.run() needs the live session, so this loop stays inside the `with`.
    for v in tf.trainable_variables():
        print(v.name)
        print(sess.run(v))

    # Print the name of every global variable in the graph.
    for v in tf.global_variables():
        print(v.name)

    # Print the name of every operation in the session's graph.
    for o in sess.graph.get_operations():
        print(o.name)
ckpt转pb模型1
"""
Demo script showing detections in sample images.
See README.md for installation instructions before running.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from lib.model.config import cfg
from lib.model.test import im_detect
from lib.model.nms_wrapper import nms
from lib.utils.timer import Timer
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import os, cv2
import argparse
from tensorflow.python.tools import freeze_graph
from tensorflow.python.framework import graph_util
from lib.nets.vgg16 import vgg16
from lib.nets.resnet_v1 import resnetv1
# Class labels; the first entry is the background label.
CLASSES = ('__background__', 'fire', 'smoke')

# Network key -> 1-tuple holding a checkpoint file name.
NETS = {
    'vgg16': ('vgg16_faster_rcnn_iter_110000.ckpt',),
    'res101': ('res101_faster_rcnn_iter_110000.ckpt',),
}

# Dataset key -> 1-tuple holding a dataset split name (presumably the
# training split; confirm against where DATASETS is consumed).
DATASETS = {
    'pascal_voc': ('voc_2007_trainval',),
    'pascal_voc_0712': ('voc_2007_trainval+voc_2012_trainval',),
}
def vis_detections(im, class_name, dets, thresh=0.5):
    """Draw detected bounding boxes.

    Opens a new 12x12 matplotlib figure showing ``im`` and overlays one red
    rectangle plus a "<class_name> <score>" caption for every detection whose
    score is at least ``thresh``. Nothing is drawn, and no figure is created,
    when no detection reaches the threshold.

    Args:
        im: Image array indexed as [row, column, channel] with three
            channels. The channel order is reversed before display
            (presumably BGR from OpenCV to RGB; confirm against the caller).
        class_name: Label used in the box captions and in the figure title.
        dets: 2-D array with one detection per row. Columns 0-3 are used as
            the box corners (x1, y1, x2, y2) and the last column as the score.
        thresh: Minimum score a detection needs in order to be drawn.

    Returns:
        None.
    """
    # Row indices of the detections that pass the score threshold.
    inds = np.where(dets[:, -1] >= thresh)[0]
    if inds.size == 0:
        return

    # Reverse the channel axis before handing the image to matplotlib.
    im = im[:, :, (2, 1, 0)]
    # Only the axes are needed; the later plt.* calls act on the current
    # figure, which is the one created here.
    _, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(im, aspect='equal')

    for i in inds:
        bbox = dets[i, :4]
        score = dets[i, -1]

        # Rectangle takes an anchor corner plus width and height, so the
        # second corner is converted into extents.
        ax.add_patch(
            plt.Rectangle((bbox[0], bbox[1]),
                          bbox[2] - bbox[0],
                          bbox[3] - bbox[1], fill=False,
                          edgecolor='red', linewidth=3.5)
        )
        # Caption placed just above the box's anchor corner.
        ax.text(bbox[0], bbox[1] - 2,
                '{:s} {:.3f}'.format(class_name, score),
                bbox=dict(facecolor='blue', alpha=0.5),
                fontsize=14, color='white')

    ax.set_title(('{} detections with '
                  'p({} | box) >= {:.1f}').format(class_name, class_name,
                                                  thresh),
                 fontsize=14)
    plt.axis('off')
    plt.tight_layout()
    plt.draw()
def demo(sess, net, image_name):
"""Detect object classes in an image using pre-computed object proposals."""
# Load the demo image
im_file = os.path.join(cfg.DATA_DIR, 'demo', image_name)
im = cv2.imread(im_file)
# Detect all object classes and regress object bounds
timer = Timer()
timer.tic()
scores, boxes = im_detect(sess, net, im)
timer.toc(