【SSD初探】TensorFlow-SSD实现目标识别

准备

利用训练好的模型文件,进行图片测试:

(1)下载源文件:https://github.com/balancap/SSD-Tensorflow

(2)解压后,将下载的 SSD-Tensorflow-master 解压,将“./checkpoints”文件下的“ssd_300_vgg.ckpt”。

最终模型文件的存放位置,就是“./checkpoints/”

测试程序

import tensorflow as tf
import os
import cv2
import numpy as np
import random
import math

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"  # 忽略tensorflow警告信息
slim = tf.contrib.slim

import matplotlib.pyplot as plt
import matplotlib.image as mpimg

import sys
sys.path.append('./')

from nets import ssd_vgg_300, ssd_common,np_methods
from preprocessing import ssd_vgg_preprocessing
from notebooks import visualization

gpu_options = tf.GPUOptions(allow_growth=True)
config = tf.ConfigProto(log_device_placement=False,gpu_options=gpu_options)
isess = tf.InteractiveSession(config=config)

#input
net_shape = (300,300)
data_format = 'NHWC'
img_input = tf.placeholder(tf.uint8,shape=(None,None,3))

image_pre,labels_pre,bboxes_pre,bbox_img = ssd_vgg_preprocessing.preprocess_for_eval(
    img_input,None,None,net_shape,data_format,resize=ssd_vgg_preprocessing.Resize.WARP_RESIZE)
image_4d = tf.expand_dims(image_pre,0)

reuse = True if 'ssd_net' in locals() else None
ssd_net = ssd_vgg_300.SSDNet()
with slim.arg_scope(ssd_net.arg_scope(data_format=data_format)):
    predictions,localisations,_,_ = ssd_net.net(image_4d,is_training=False,reuse=reuse)

#加载ckpt训练文件
ckpt_filename = './checkpoints/ssd_300_vgg.ckpt'
isess.run(tf.global_variables_initializer())
saver  = tf.train.Saver()
saver.restore(isess,ckpt_filename)
ssd_anchors = ssd_net.anchors(net_shape)

def process_image(img, select_threshold=0.5, nms_threshold=.45, net_shape=(300, 300)):
    # Run SSD network.
    rimg, rpredictions, rlocalisations, rbbox_img = isess.run([image_4d, predictions, localisations, bbox_img],
                                                              feed_dict={img_input: img})

    # Get classes and bboxes from the net outputs.
    rclasses, rscores, rbboxes = np_methods.ssd_bboxes_select(
        rpredictions, rlocalisations, ssd_anchors,
        select_threshold=select_threshold, img_shape=net_shape, num_classes=21, decode=True)

    rbboxes = np_methods.bboxes_clip(rbbox_img, rbboxes)
    rclasses, rscores, rbboxes = np_methods.bboxes_sort(rclasses, rscores, rbboxes, top_k=400)
    rclasses, rscores, rbboxes = np_methods.bboxes_nms(rclasses, rscores, rbboxes, nms_threshold=nms_threshold)
    rbboxes = np_methods.bboxes_resize(rbbox_img, rbboxes)
    return rclasses, rscores, rbboxes

img = mpimg.imread('ourLaboratory.jpg') #测试图片的完整路径
rclasses, rscores, rbboxes =  process_image(img)

visualization.plt_bboxes(img, rclasses, rscores, rbboxes)

测试结果

运行上面的测试程序,输入完成测图片路径,下面是测试效果。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值