#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Author: Marvin Teichmann
"""
Classify an image using KittiClass.
Input: Image
Output: Image (annotated with the predicted road type: 'Highway' or 'small road')
Utilizes: Trained KittiClass weights. If no logdir is given,
pretrained weights will be downloaded and used.
Usage:
python demo.py --input data/demo.png [--output output]
[--logdir /path/to/weights] [--gpus 0]
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import logging
import os
import sys
import collections
# Configure logging so that key checkpoints of the run are recorded:
# timestamped messages at INFO level and above, written to stdout.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format='%(asctime)s %(levelname)s %(message)s',
)
# https://github.com/tensorflow/tensorflow/issues/2034#issuecomment-220820070
import numpy as np
import scipy as scp
import scipy.misc
import tensorflow as tf
# tf.app.flags receives the arguments passed on the command line and takes
# care of parsing them.
flags = tf.app.flags
# FLAGS is the object through which the parsed flag values are read.
FLAGS = flags.FLAGS
# Insert 'incl' at index 1 of the module search path so that it is checked by
# `import` before every entry that follows it.
sys.path.insert(1, 'incl')
# NOTE(original author): this checkout has no 'incl' directory, so the path
# entry above is not actually used here.
try:
    # Check whether setup was done correctly
    import tensorvision.utils as tv_utils
    import tensorvision.core as core
except ImportError:
    # You forgot to initialize submodules
    logging.error("Could not import the submodules.")
    logging.error("Please execute:"
                  "'git submodule update --init --recursive'")
    # sys.exit rather than the bare exit() helper: the latter is injected by
    # the `site` module for interactive use and is not guaranteed to exist.
    sys.exit(1)
# Command-line flags. flags.DEFINE_string() declares a flag holding a string
# value; its three arguments are the flag name, the default value and the
# help text.
#
# NOTE(original author, cjh 2019-04-16): the default of 'logdir' was None in
# the original program; it is set by hand here because this checkout has no
# 'incl' directory. main() only downloads the pretrained weights when this
# flag is None, so a non-None default disables the automatic download.
flags.DEFINE_string('logdir', 'logdir',
                    'Path to logdir.')
flags.DEFINE_string('input', 'DATA/demo/007007.png',
                    'Image to apply KittiSeg.')
# The help text used to repeat the one of --input; it now describes the flag.
flags.DEFINE_string('output', 'DATA/demo/007007_.png',
                    'Path the annotated output image is written to.')
# Name of the folder (inside the runs directory) that holds the pretrained
# KittiClass weights.
default_run = 'KittiClass_postpaper'

# Archive with the pretrained network weights. It lives under
# ftp://mi.eng.cam.ac.uk/pub/mttt2/models/ and can also be downloaded by hand.
weights_url = (
    "ftp://mi.eng.cam.ac.uk/pub/mttt2/models/"
    "KittiClass_postpaper.zip"
)
from PIL import Image, ImageDraw, ImageFont
def _load_label_font(size):
    """Return a font for the label, falling back if a font file is missing."""
    for font_name in ('simhei.ttf', 'FreeMono/FreeMonoBold.ttf'):
        try:
            return ImageFont.truetype(font_name, size)
        except (IOError, OSError):
            # Font file not available on this machine; try the next one.
            continue
    # Last resort: PIL's built-in font, so a missing font file does not
    # abort the demo. NOTE(review): the built-in font may not honour `size`.
    return ImageFont.load_default()


def road_draw(image, highway):
    """Draw the classification result onto an image.

    A filled circle and a text label are painted in the top-left corner:
    yellow with the text "Highway" when `highway` is true, red with the text
    "small road" otherwise.

    Args:
        image: Array holding the picture; it is converted to uint8 before
            being handed to PIL (presumably H x W x 3 RGB -- confirm with
            the caller).
        highway: Truthy if the image was classified as a highway.

    Returns:
        The annotated image as a float32 numpy array.
    """
    im = Image.fromarray(image.astype('uint8'))
    draw = ImageDraw.Draw(im)
    fnt = _load_label_font(40)

    if highway:
        label, color = "Highway", (255, 255, 0, 255)
    else:
        label, color = "small road", (255, 0, 0, 255)

    draw.text((65, 10), label, font=fnt, fill=color)
    draw.ellipse([10, 10, 55, 55], fill=color, outline=color)

    return np.array(im).astype('float32')
def maybe_download_and_extract(runs_dir):
    """Download and unpack the pretrained weights unless already present.

    If `<runs_dir>/<default_run>` exists nothing is done. Otherwise the
    archive at `weights_url` is fetched with `tv_utils.download` and
    extracted into `runs_dir`.

    Args:
        runs_dir: Directory the weights folder is looked up in and the
            archive is extracted to.
    """
    logdir = os.path.join(runs_dir, default_run)
    if os.path.exists(logdir):
        # Weights are already downloaded. Nothing to do.
        return

    import zipfile
    # The value returned by tv_utils.download is used as the archive path.
    download_name = tv_utils.download(weights_url, runs_dir)
    # Log the archive actually being unpacked (the old message named an
    # unrelated KittiSeg archive).
    logging.info("Extracting %s", download_name)
    # The context manager closes the archive even if extraction fails.
    with zipfile.ZipFile(download_name, 'r') as archive:
        archive.extractall(runs_dir)
def resize_label_image(image, gt_image, image_height, image_width):
    """Resize an image and its label image to the same target size.

    Args:
        image: The original picture; resized with cubic interpolation.
        gt_image: The label (ground-truth) picture; resized with
            nearest-neighbour interpolation.
        image_height: Target height.
        image_width: Target width.

    Returns:
        Tuple `(image, gt_image)` with both pictures resized.
    """
    # NOTE(review): scipy.misc.imresize is not available in recent SciPy
    # releases -- confirm the SciPy version this project is pinned to.
    target_size = (image_height, image_width)
    image = scp.misc.imresize(image, size=target_size, interp='cubic')
    gt_image = scp.misc.imresize(gt_image, size=target_size,
                                 interp='nearest')
    return image, gt_image
def main(_):
    """Classify the image given by --input and save an annotated copy.

    Steps: pick the GPUs, locate the weights directory (downloading the
    pretrained weights when --logdir is None), load hypes and modules from
    it, build the inference graph, restore the weights, run the network on
    the input image and write the image with the drawn result to disk.

    Args:
        _: Unused positional argument supplied by the caller of main.
    """
    # Select the GPU(s) the code runs on.
    tv_utils.set_gpus_to_use()

    if FLAGS.input is None:
        logging.error("No input was given.")
        logging.info(
            "Usage: python demo.py --input data/test.png "
            "[--output output] [--logdir /path/to/weights] "
            "[--gpus GPUs_to_use] ")
        sys.exit(1)

    # Find the directory the weights are stored in and bind it to `logdir`.
    if FLAGS.logdir is None:
        # Download and use weights from the MultiNet Paper
        if 'TV_DIR_RUNS' in os.environ:
            runs_dir = os.path.join(os.environ['TV_DIR_RUNS'],
                                    'KittiClass')
        else:
            runs_dir = 'RUNS'
        maybe_download_and_extract(runs_dir)
        logdir = os.path.join(runs_dir, default_run)
    else:
        logging.info("Using weights found in {}".format(FLAGS.logdir))
        logdir = FLAGS.logdir

    # Loading hyperparameters from logdir
    hypes = tv_utils.load_hypes_from_logdir(logdir, base_path='hypes')
    logging.info("Hypes loaded successfully.")

    # Loading tv modules (encoder.py, decoder.py, eval.py) from logdir
    modules = tv_utils.load_modules_from_logdir(logdir)
    logging.info("Modules loaded successfully. Starting to build tf graph.")

    # Create tf graph and build module.
    with tf.Graph().as_default():
        # Create placeholder for input
        image_pl = tf.placeholder(tf.float32)
        image_batch = tf.expand_dims(image_pl, 0)

        # build Tensorflow graph using the model from logdir
        prediction = core.build_inference_graph(hypes, modules,
                                                image=image_batch)
        logging.info("Graph build successfully.")

        # Create a session for running Ops on the Graph; the `with` block
        # releases its resources once inference is done.
        with tf.Session() as sess:
            # A Saver object is needed to restore the stored model.
            saver = tf.train.Saver()

            # Load weights from the same logdir the hypes and modules came
            # from (no machine-specific path override).
            core.load_weights(logdir, sess, saver)
            logging.info("Weights loaded successfully.")

            input_path = FLAGS.input
            logging.info(
                "Starting inference using {} as input".format(input_path))

            # Load and resize input image
            image = scp.misc.imread(input_path)
            # NOTE(review): 'reseize_image' looks like a typo but it is the
            # key this code reads; confirm against the hypes file before
            # renaming it.
            if hypes['jitter']['reseize_image']:
                # Resize input only, if specified in hypes
                image_height = hypes['jitter']['image_height']
                image_width = hypes['jitter']['image_width']
                image = scp.misc.imresize(image,
                                          size=(image_height, image_width),
                                          interp='cubic')

            # Run the model on the image; `feed` assigns the placeholder.
            feed = {image_pl: image}
            softmax_road, _ = prediction['softmax']
            # `output` holds the class probabilities.
            output = sess.run([softmax_road], feed_dict=feed)

    # Get predicted class: index 0 is treated as "highway".
    highway = (np.argmax(output[0][0]) == 0)

    # Draw resulting output image
    new_img = road_draw(image, highway)

    # Save output images to disk.
    if FLAGS.output is None:
        # splitext strips only the file extension, so dots elsewhere in the
        # path (e.g. './data/x.png') do not truncate the name.
        out_image_name = os.path.splitext(input_path)[0] + '_out.png'
    else:
        out_image_name = FLAGS.output

    scp.misc.imsave(out_image_name, new_img)

    logging.info("")
    logging.info("Output image has been saved to: {}".format(
        os.path.realpath(out_image_name)))
# Script entry point. Per the original author's note, tf.app.run() only
# designates main as the main function and parses the flags once more.
if __name__ == "__main__":
    tf.app.run()
# KittiClass demo program study notes
# (blog page footer: latest recommended article published 2023-09-22 10:58:47)