Kitticlass demo程序学习

#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Author: Marvin Teichmann


"""
Classify an image using KittiClass.

Input: Image
Output: Image (with Cars plotted in Green)

Utilizes: Trained KittiClass weights. If no logdir is given,
pretrained weights will be downloaded and used.

Usage:
python demo.py --input data/demo.png [--output output]
                [--logdir /path/to/weights] [--gpus 0]


"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import logging
import os
import sys

import collections

# configure logging
#定义日志,记录关键节点信息
logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s',
                    level=logging.INFO,
                    stream=sys.stdout)

# https://github.com/tensorflow/tensorflow/issues/2034#issuecomment-220820070
import numpy as np
import scipy as scp
import scipy.misc
import tensorflow as tf


flags = tf.app.flags #用于接受命令行传递参数,处理命令行参数的解析工作,(JSON)
FLAGS = flags.FLAGS #构造一个解析器FLAGS

sys.path.insert(1, 'incl')#定义搜索的优先顺序
#新添加的目录会优先于其他目录被import检查
#我的项目中没有这个文件incl,所以没有用到,见65行
try:
    # Check whether setup was done correctly

    import tensorvision.utils as tv_utils
    import tensorvision.core as core
except ImportError:
    # You forgot to initialize submodules
    logging.error("Could not import the submodules.")
    logging.error("Please execute:"
                  "'git submodule update --init --recursive'")
    exit(1)

#原本程序中是None,我没有incl文件说以人为指定第二个参数
flags.DEFINE_string('logdir', 'logdir',
                    'Path to logdir.')
#tf.app.flags.DEFINE_string() 定义一个用于接收string 类型数值的变量
#三个参数分别是:变量名称、默认值、用法描述,
#cjh-2019-4-16
flags.DEFINE_string('input', 'DATA/demo/007007.png',
                    'Image to apply KittiSeg.')
flags.DEFINE_string('output', 'DATA/demo/007007_.png',
                    'Image to apply KittiSeg.')


default_run = 'KittiClass_postpaper' #文件夹的名字
#可能是加载欲训练网络的权重,
weights_url = ("ftp://mi.eng.cam.ac.uk/"
               "pub/mttt2/models/KittiClass_postpaper.zip")
#路径为ftp://mi.eng.cam.ac.uk/pub/mttt2/models/
#自行手动下载

from PIL import Image, ImageDraw, ImageFont

#在图片上画结果
def road_draw(image, highway):
    im = Image.fromarray(image.astype('uint8'))
    draw = ImageDraw.Draw(im)

    #fnt = ImageFont.truetype('FreeMono/FreeMonoBold.ttf', 40)
    fnt = ImageFont.truetype('simhei.ttf', 40)
    shape = image.shape

    if highway:
        draw.text((65, 10), "Highway",
                  font=fnt, fill=(255, 255, 0, 255))

        draw.ellipse([10, 10, 55, 55], fill=(255, 255, 0, 255),
                     outline=(255, 255, 0, 255))
    else:
        draw.text((65, 10), "small road",
                  font=fnt, fill=(255, 0, 0, 255))

        draw.ellipse([10, 10, 55, 55], fill=(255, 0, 0, 255),
                     outline=(255, 0, 0, 255))

    return np.array(im).astype('float32')

#runs_dir为输入值,目的是构造路径,整个函数的目的就是下载权重,
# 然后解压zip文件,如果存在则不进行任何操作。
def maybe_download_and_extract(runs_dir):
    logdir = os.path.join(runs_dir, default_run)#构造路径

    if os.path.exists(logdir):
        # weights are downloaded. Nothing to do
        return
    #解压文件
    import zipfile
    download_name = tv_utils.download(weights_url, runs_dir)#下载

    logging.info("Extracting KittiSeg_pretrained.zip")

    zipfile.ZipFile(download_name, 'r').extractall(runs_dir)

    return

#重新定义图片的大小,第一个是原图,第二个是标记图像
def resize_label_image(image, gt_image, image_height, image_width):
    image = scp.misc.imresize(image, size=(image_height, image_width),
                              interp='cubic')
    shape = gt_image.shape
    gt_image = scp.misc.imresize(gt_image, size=(image_height, image_width),
                                 interp='nearest')

    return image, gt_image


def main(_):
    ##设置运行代码的GPU
    tv_utils.set_gpus_to_use()
    if FLAGS.input is None:
        logging.error("No input was given.")
        logging.info(
            "Usage: python demo.py --input data/test.png "
            "[--output output] [--logdir /path/to/weights] "
            "[--gpus GPUs_to_use] ")
        exit(1)

    if FLAGS.logdir is None:
        # Download and use weights from the MultiNet Paper
        if 'TV_DIR_RUNS' in os.environ:
            runs_dir = os.path.join(os.environ['TV_DIR_RUNS'],
                                    'KittiClass')
        else:
            runs_dir = 'RUNS'
        maybe_download_and_extract(runs_dir)
        logdir = os.path.join(runs_dir, default_run)
    else:
        logging.info("Using weights found in {}".format(FLAGS.logdir))
        logdir = FLAGS.logdir
    #这段在做的就是找到权重的存放路径并赋值给logdir
    # Loading hyperparameters from logdir,从下载的权重的存放路径加载超参
    hypes = tv_utils.load_hypes_from_logdir(logdir, base_path='hypes')

    logging.info("Hypes loaded successfully.")

    # Loading tv modules (encoder.py, decoder.py, eval.py) from logdir
    modules = tv_utils.load_modules_from_logdir(logdir)
    logging.info("Modules loaded successfully. Starting to build tf graph.")

    # Create tf graph and build module.
    with tf.Graph().as_default():
        # Create placeholder for input
        image_pl = tf.placeholder(tf.float32)
        image = tf.expand_dims(image_pl, 0)

        # build Tensorflow graph using the model from logdir
        prediction = core.build_inference_graph(hypes, modules,
                                                image=image)

        logging.info("Graph build successfully.")

        # Create a session for running Ops on the Graph.
        sess = tf.Session()
        saver = tf.train.Saver()#模型保存,先要创建一个saver对象

        # Load weights from logdir
        #cjh-2019-4-16
        # logdir=r'/home/xue/MultiNet-master/submodules/KittiClass/KittiClass_postpaper'
        logdir = r'E:/lixueqian/2019/new_method/MultiNet-master1/submodules/KittiClass/KittiClass_postpaper'
        core.load_weights(logdir, sess, saver)

        logging.info("Weights loaded successfully.")

    input = FLAGS.input
    logging.info("Starting inference using {} as input".format(input))  #使用的方法

    # Load and resize input image
    # input=r'/home/xue/MultiNet-master/submodules/KittiClass/KittiClass-master/DATA/demo/007034.png'
    image = scp.misc.imread(input)
    if hypes['jitter']['reseize_image']:
        # Resize input only, if specified in hypes
        image_height = hypes['jitter']['image_height']
        image_width = hypes['jitter']['image_width']
        image = scp.misc.imresize(image, size=(image_height, image_width),
                                  interp='cubic')

    # Run KittiSeg model on image
    feed = {image_pl: image} #为placeholder赋值
    softmax_road, _ = prediction['softmax']#
    output = sess.run([softmax_road], feed_dict=feed)
    # output是分类的概率
    # Get predicted class
    highway = (np.argmax(output[0][0]) == 0)
    # highway是将output转化为bool值
    # Draw resulting output image
    new_img = road_draw(image, highway)

    # Save output images to disk.
    if FLAGS.output is None:
        output_base_name = input
        out_image_name = output_base_name.split('.')[0] + '_out.png'
    else:
        out_image_name = FLAGS.output

    scp.misc.imsave(out_image_name, new_img)

    logging.info("")
    logging.info("Output image has been saved to: {}".format(
        os.path.realpath(out_image_name)))

if __name__ == '__main__':
    tf.app.run()  #使用flags又解析了一次,tf.app.run的作用仅仅是指定main主函数和使用flags再解析一次输入

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值