KittiSeg

 

 

1、inputs/kitti_seg_input.py:确保本地训练数据目录中已经下载了正确的数据集,然后将这些数据解压,并返回一个 DataSet 实例的字典。

2、train.py

tf.app.flags.FLAGS.hypes:hypes/KittiSeg.json,即超参数配置文件(hypes)的路径

print("tf.app.flags.FLAGS.hypes",tf.app.flags.FLAGS.hypes )

with open(tf.app.flags.FLAGS.hypes, 'r') as f:         用 with 语句打开文件,'r' 表示只读模式;离开 with 块时文件会自动关闭

utils.load_plugins():  hypes/KittiSeg.json 为basedir

import ast

#ast.literal_eval 将字符串安全地解析为 Python 字面量(这里得到一个 dict):只接受合法的字面量类型,不会执行任意代码

mod_dict = ast.literal_eval(tf.app.flags.FLAGS.mod)

dict_merge(hypes, mod_dict)

 两个dict 合并:

dict1={1:[1,11,111],2:[2,22,222]}  
dict2={3:[3,33,333],4:[4,44,444]}  
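方法1(仅适用于 Python 2;在 Python 3 中 dict.items() 返回视图对象,不能用 + 相加,需改写为 dict(list(dict1.items()) + list(dict2.items()))):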

dictMerged1=dict(dict1.items()+dict2.items())  

方法2(注意:在 Python 3 中 **dict2 要求 dict2 的键全部是字符串,本例的整数键会抛出 TypeError;Python 3.5+ 推荐使用 {**dict1, **dict2}):
dictMerged2=dict(dict1, **dict2)

 合并
{1:[1,11,111],2:[2,22,222],3:[3,33,333],4:[4,44,444]}

os.environ['TV_DIR_RUNS'] = os.path.join(os.environ['TV_DIR_RUNS'], 'KittiSeg')

os.path.join(path,name):连接目录与文件名或目录

>>> os.path.join('c:\\Python','a.txt')
'c:\\Python\\a.txt'
>>> os.path.join('c:\\Python','f1')
'c:\\Python\\f1'
>>> 

utils.set_dirs(hypes, tf.app.flags.FLAGS.hypes)

->

# Set base_path
    if 'base_path' not in hypes['dirs']:
        base_path = os.path.dirname(os.path.realpath(hypes_fname))
        hypes['dirs']['base_path'] = base_path
    else:
        base_path = hypes['dirs']['base_path']

 os.path.dirname(path):返回 path 中的目录部分(去掉最后一级文件名)

>>> os.path.dirname('c:\\Python\\a.txt')
'c:\\Python'

 os.path.realpath(hypes_fname):返回 hypes_fname 的规范化绝对路径(会解析符号链接)

 print os.path.realpath()

#base_path    hypes/

  # Set output dir
    if 'output_dir' not in hypes['dirs']:
        if 'TV_DIR_RUNS' in os.environ:
            runs_dir = os.path.join(base_path, os.environ['TV_DIR_RUNS'])
        else:
            runs_dir = os.path.join(base_path, '../RUNS')

#runs_dir   RUNS/

# test for project dir
        if hasattr(FLAGS, 'project') and FLAGS.project is not None:
            runs_dir = os.path.join(runs_dir, FLAGS.project)

hasattr(FLAGS,'project')   true,

hasattr(object, name)

   判断object对象中是否存在name属性,当然对于python的对象而言,属性包含变量和方法;有则返回True,没有则返回False;需要注意的是name参数是string类型,所以不管是要判断变量还是方法,其名称都以字符串形式传参;getattr和setattr也同样;

>>> 
>>> class A():
    name = 'python'
    def func(self):
        return 'A()类的方法func()'

    
>>> 
>>> hasattr(A, 'name')
True
>>> 
>>> hasattr(A, 'age')
False
>>> 
>>> hasattr(A, 'func')
True
>>>
    if 'output_dir' not in hypes['dirs']:
        if 'TV_DIR_RUNS' in os.environ:
            runs_dir = os.path.join(base_path, os.environ['TV_DIR_RUNS'])
        else:
            runs_dir = os.path.join(base_path, '../RUNS')

        # test for project dir
        if hasattr(FLAGS, 'project') and FLAGS.project is not None:
            runs_dir = os.path.join(runs_dir, FLAGS.project)

        if not FLAGS.save and FLAGS.name is None:
            output_dir = os.path.join(runs_dir, 'debug')
        else:
            json_name = hypes_fname.split('/')[-1].replace('.json', '')
            date = datetime.now().strftime('%Y_%m_%d_%H.%M')
            if FLAGS.name is not None:
                json_name = FLAGS.name + "_" + json_name
            run_name = '%s_%s' % (json_name, date)
            output_dir = os.path.join(runs_dir, run_name)

        hypes['dirs']['output_dir'] = output_dir

 run_name:KittiSeg.json 去掉 .json 后缀,再加上 '_' + 日期;若指定了 FLAGS.name,前面还会再加上 name + '_'。若既没有设置 save 也没有指定 name,输出目录则是 runs_dir/debug

 output_dir = RUNS/KittiSeg_date/

    if 'data_dir' not in hypes['dirs']:
        if 'TV_DIR_DATA' in os.environ:
            data_dir = os.path.join(base_path, os.environ['TV_DIR_DATA'])
        else:
            data_dir = os.path.join(base_path, '../DATA')

        hypes['dirs']['data_dir'] = data_dir

data_dir=DATA/

def _add_paths_to_sys(hypes):
    """
    Put every directory listed in ``hypes['path']`` onto ``sys.path``.

    Each entry is joined onto ``hypes['dirs']['base_path']`` and
    canonicalised with ``os.path.realpath``. Entries are inserted one by
    one at index 1 (right after ``sys.path[0]``), so a later entry ends up
    in front of an earlier one. Nothing happens when ``hypes`` has no
    ``'path'`` key.

    Parameters
    ----------
    hypes : dict
        Hyperparameters
    """
    root = hypes['dirs']['base_path']
    for extra in hypes.get('path', ()):
        sys.path.insert(1, os.path.realpath(os.path.join(root, extra)))

sys.path.insert(1, path):把 path 插入到模块搜索路径 sys.path 的第 2 个位置(第 1 个位置通常是脚本所在目录),使其优先于其余路径被搜索;这里加入的是 incl/ 目录

utils.set_dirs(hypes, tf.app.flags.FLAGS.hypes)

#base_path    hypes/

#runs_dir   RUNS/

output_dir = RUNS/KittiSeg_date/

data_dir=DATA/

def maybe_download_and_extract(hypes):
    """
    Download the data if it isn't downloaded by now.

    Loads the input module given by ``hypes['model']['input_file']``
    (resolved relative to ``hypes['dirs']['base_path']``) and, if that
    module defines its own ``maybe_download_and_extract``, calls it with
    ``hypes``.

    Parameters
    ----------
    hypes : dict
        Hyperparameters
    """
    import importlib.machinery
    import importlib.util
    import sys

    f = os.path.join(hypes['dirs']['base_path'], hypes['model']['input_file'])
    # Equivalent of the deprecated ``imp.load_source("input", f)`` (``imp``
    # was removed in Python 3.12): execute the file as a module called
    # "input", registered in sys.modules under that name.
    loader = importlib.machinery.SourceFileLoader("input", f)
    spec = importlib.util.spec_from_loader("input", loader)
    data_input = importlib.util.module_from_spec(spec)
    sys.modules["input"] = data_input
    loader.exec_module(data_input)

    # The download hook is optional in an input module.
    if hasattr(data_input, 'maybe_download_and_extract'):
        data_input.maybe_download_and_extract(hypes)

f=inputs/kitti_seg_input.py

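imp.load_source(name, pathname[, file]) 的作用是把源文件 pathname 作为名为 name 的模块加载并返回该模块对象(同时注册到 sys.modules[name] 中),name 可以是自定义的名字或者内置的模块名称。注意:imp 模块自 Python 3.4 起已弃用,并在 Python 3.12 中被移除,新代码应改用 importlib。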

即把 kitti_seg_input.py 加载为名为 input 的模块,之后可以通过 data_input 直接调用其中的函数

假设在路径E:/Code/Python3/下有一个文件test.py, 内容如下:

# --- contents of E:/Code/Python3/test.py ---
def myadd(x, y):
	return(x + y)
# --- a separate script that loads test.py by file path ---
# NOTE: imp is deprecated (removed in Python 3.12); importlib replaces it.
import imp
# Load test.py as a module named 'mymod' and get the module object back.
m = imp.load_source('mymod', 'E:/Code/Python3/test.py')
 
# Method 1: call the function through the returned module object.
a = m.myadd(4, 10)
print(a)
 
# Method 2: load_source registered the loaded module under the name 'mymod'
# (in sys.modules), which is why a plain import finds it afterwards.
import mymod
a = mymod.myadd(4, 10)
print(a)

如果data_input(kitti_seg_input.py)中有maybe_download_and_extract()函数

执行函数maybe_download_and_extract(hypes)

def maybe_download_and_extract(hypes):
    """Download, extract and prepare the KITTI road data and VGG16 weights.

    Everything lives under ``hypes['dirs']['data_dir']``, which is created
    if missing:

    * ``data_road/`` -- extracted from ``data_road.zip``, which is downloaded
      from ``hypes['data']['kitti_url']``. ``data/train3.txt`` and
      ``data/val3.txt`` are then copied into it.
    * ``vgg16.npy`` -- downloaded from ``hypes['data']['vgg_url']``.

    Each item is fetched only if it is missing. Placing both ``data_road/``
    and ``vgg16.npy`` in the data dir beforehand skips all downloads.

    Parameters
    ----------
    hypes : dict
        Hyperparameters

    Raises
    ------
    SystemExit
        With code 1 if the KITTI road data is missing and ``kitti_url`` is
        empty or does not end in ``kitti/data_road.zip``.
    """
    import sys

    data_dir = hypes['dirs']['data_dir']
    if not os.path.exists(data_dir):
        os.makedirs(data_dir)

    data_road_zip = os.path.join(data_dir, 'data_road.zip')
    vgg_weights = os.path.join(data_dir, 'vgg16.npy')
    kitti_road_dir = os.path.join(data_dir, 'data_road/')

    need_kitti = not os.path.exists(kitti_road_dir)
    need_vgg = not os.path.exists(vgg_weights)
    if not (need_kitti or need_vgg):
        return

    if need_kitti:
        # Validate the (manually requested) KITTI link before doing any work.
        kitti_data_url = hypes['data']['kitti_url']
        if kitti_data_url == '':
            problem = "Data URL for Kitti Data not provided."
        elif not kitti_data_url.endswith('kitti/data_road.zip'):
            problem = "Wrong url."
        else:
            problem = None
        if problem is not None:
            logging.error(problem)
            url = "http://www.cvlibs.net/download.php?file=data_road.zip"
            logging.error("Please visit: {}".format(url))
            logging.error("and request Kitti Download link.")
            logging.error("Enter URL in hypes/kittiSeg.json")
            sys.exit(1)

    # Imported lazily: only needed when something actually must be fetched.
    import tensorvision.utils as utils

    if need_kitti:
        import zipfile
        from shutil import copy2

        # Download KITTI DATA
        logging.info("Downloading Kitti Road Data.")
        utils.download(kitti_data_url, data_dir)
        # Extract and prepare KITTI DATA
        logging.info("Extracting kitti_road data.")
        with zipfile.ZipFile(data_road_zip, 'r') as archive:
            archive.extractall(data_dir)

        logging.info("Preparing kitti_road data.")
        # NOTE(review): these list files are resolved against the current
        # working directory, not data_dir -- presumably the script is meant
        # to be run from the repository root.
        copy2("data/train3.txt", kitti_road_dir)
        copy2("data/val3.txt", kitti_road_dir)

    if need_vgg:
        # Download VGG DATA
        vgg_url = hypes['data']['vgg_url']
        logging.info("Downloading VGG weights.")
        utils.download(vgg_url, data_dir)
    return

    只要提前把 data_road/ 目录和 vgg16.npy 放到 DATA 中,函数就会直接返回,不执行任何下载(是否存在 data_road.zip 不影响这个判断)。

   data_dir =DATA/

    data_road_zip=DATA/data_road.zip

     vgg_weights= DATA/vgg16.npy

     kitti_road_dir = DATA/data_road/


def initialize_training_folder(hypes, files_dir="model_files", logging=True):
    """Create the output folder of a training run and snapshot its setup.

    Creates ``<output_dir>/<files_dir>`` and ``<output_dir>/images`` if they
    are missing, and records the images folder in
    ``hypes['dirs']['image_dir']``. It then optionally saves the console
    output into ``<output_dir>/output.log``, writes ``hypes`` to
    ``<files_dir>/hypes.json`` and copies the model source files into
    ``<files_dir>``.

    Parameters
    ----------
    hypes : dict
        Hyperparameters; ``hypes['dirs']['output_dir']`` must be set.
        Modified in place (``hypes['dirs']['image_dir']`` is added).
    files_dir : str, optional
        Name of the sub-folder that receives hypes.json and the source copies.
    logging : bool, optional
        If true, create a file handler writing to ``<output_dir>/output.log``.
        Note: this parameter shadows the ``logging`` module inside the
        function; the name is kept so existing keyword callers keep working.
    """
    output_dir = hypes['dirs']['output_dir']

    # e.g. RUNS/KittiSeg_<date>/model_files/
    target_dir = os.path.join(output_dir, files_dir)
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)

    # e.g. RUNS/KittiSeg_<date>/images/
    image_dir = os.path.join(output_dir, "images")
    if not os.path.exists(image_dir):
        os.makedirs(image_dir)

    hypes['dirs']['image_dir'] = image_dir

    # Creating an additional logging saving the console outputs
    # into the training folder.
    if logging:
        logging_file = os.path.join(output_dir, "output.log")
        utils.create_filewrite_handler(logging_file)

    # TODO: read more about loggers and make file logging neater.
    # Written after image_dir is set, so the snapshot includes it.
    target_file = os.path.join(target_dir, 'hypes.json')
    with open(target_file, 'w') as outfile:
        json.dump(hypes, outfile, indent=2, sort_keys=True)

    # Copy each model component next to hypes.json under a fixed file name.
    for hypes_key, copy_name in (
            ('input_file', "data_input.py"),
            ('architecture_file', "architecture.py"),
            ('objective_file', "objective.py"),
            ('optimizer_file', "solver.py"),
            ('evaluator_file', "eval.py")):
        _copy_parameters_to_traindir(
            hypes, hypes['model'][hypes_key], copy_name, target_dir)

#output_dir = RUNS/KittiSeg_date/

#target_dir = RUNS/KittiSeg_date/model_files/

#image_dir = RUNS/KittiSeg_date/images/

 自己制作数据集时,标签列表文件中一定不能多出空行,谨记

#logging_dir = RUNS/KittiSeg_date/output.log

#target_file = RUNS/KittiSeg_date/model_files/hypes.json

train.initialize_training_folder(hypes)

 train.initialize_training_folder(hypes)

#初始化输出文件

def do_training(hypes):
    """Build the training and validation graphs, then run the training loop.

    In order, this:
    1. loads the model modules named in ``hypes``;
    2. opens a ``tf.Session``;
    3. creates the training input queue under the "Queues" name scope and
       builds the training graph on it;
    4. starts the TensorVision session;
    5. builds a single-image inference graph under "Validation" that reuses
       the existing variables;
    6. starts the enqueue threads and runs the training loop;
    7. finally asks the coordinator to stop and joins the input threads.

    Parameters
    ----------
    hypes : dict
        Hyperparameters
    """
    loaded = utils.load_modules_from_hypes(hypes)

    with tf.Session() as sess:
        # Build the graph based on the loaded modules.
        with tf.name_scope("Queues"):
            train_queue = loaded['input'].create_queues(hypes, 'train')

        graph_ops = core.build_training_graph(hypes, train_queue, loaded)

        # Prepare the TensorVision session.
        session_ops = core.start_tv_session(hypes)

        with tf.name_scope('Validation'):
            # Reuse the variables created by the training graph above.
            tf.get_variable_scope().reuse_variables()
            image_pl = tf.placeholder(tf.float32)
            # Add a leading batch axis and pin the shape to [1, H, W, 3].
            batched_image = tf.expand_dims(image_pl, 0)
            batched_image.set_shape([1, None, None, 3])
            inference = core.build_inference_graph(hypes, loaded,
                                                   image=batched_image)
            graph_ops['image_pl'] = image_pl
            graph_ops['inf_out'] = inference

        # Start loading data, then train on it.
        loaded['input'].start_enqueuing_threads(hypes, train_queue, 'train',
                                                sess)
        run_training(hypes, loaded, graph_ops, session_ops)

        # Stop the input threads.
        coordinator = session_ops['coord']
        coordinator.request_stop()
        coordinator.join(session_ops['threads'])

modules = utils.load_modules_from_hypes(hypes)

#导入模型训练程序

def _load_source_module(name, path):
    """Execute the Python source file *path* as a module called *name*.

    Drop-in replacement for ``imp.load_source(name, path)`` (``imp`` is
    deprecated and was removed in Python 3.12). The new module is registered
    in ``sys.modules`` under *name* before its code runs, and is returned.
    """
    import importlib.machinery
    import importlib.util
    import sys

    loader = importlib.machinery.SourceFileLoader(name, path)
    spec = importlib.util.spec_from_loader(name, loader)
    module = importlib.util.module_from_spec(spec)
    sys.modules[name] = module
    loader.exec_module(module)
    return module


def load_modules_from_hypes(hypes, postfix=""):
    """Load all model modules from the files specified in hypes.

    Each file named in ``hypes['model']`` is resolved relative to
    ``hypes['dirs']['base_path']`` and loaded as a module named
    ``"<prefix>_<postfix>"`` (e.g. ``"input_"`` with the default postfix).
    Modules are loaded in a fixed order: input, architecture, objective,
    optimizer, evaluator.

    Parameters
    ----------
    hypes : dict
        Hyperparameters. Must provide ``hypes['dirs']['base_path']`` and the
        ``input_file``, ``architecture_file``, ``objective_file``,
        ``optimizer_file`` and ``evaluator_file`` entries of
        ``hypes['model']``.
    postfix : str, optional
        Suffix appended to every module name (presumably to keep names
        distinct when several configurations are loaded).

    Returns
    -------
    dict
        Maps ``'input'``, ``'arch'``, ``'objective'``, ``'solver'`` and
        ``'eval'`` to the corresponding loaded module.
    """
    # (key in the returned dict, key in hypes['model'], module-name prefix).
    # In KittiSeg these files are inputs/kitti_seg_input.py, fcn8_vgg.py,
    # kitti_multiloss.py, generic_optimizer.py and kitti_eval.py.
    specs = (
        ('input', 'input_file', 'input'),
        ('arch', 'architecture_file', 'arch'),
        ('objective', 'objective_file', 'objective'),
        ('solver', 'optimizer_file', 'solver'),
        ('eval', 'evaluator_file', 'evaluator'),
    )
    base_path = hypes['dirs']['base_path']
    # sys.path is left untouched (_add_paths_to_sys is not called here).

    modules = {}
    for key, hypes_key, prefix in specs:
        path = os.path.join(base_path, hypes['model'][hypes_key])
        modules[key] = _load_source_module("%s_%s" % (prefix, postfix), path)
    return modules

modules['input']     kitti_seg_input.py

modules['arch']    fcn8_vgg.py

modules['objective']     kitti_multiloss.py

modules['solver']    generic_optimizer.py

modules['eval']        kitti_eval.py

with tf.Session() as sess:

#运行默认图

 with tf.name_scope("Queues"):
            queue = modules['input'].create_queues(hypes, 'train')

在inputs/kitti_seg_input.py下面找 create_queues()

def create_queues(hypes, phase):
    """Create the FIFO queue that feeds ``(image, label)`` pairs for *phase*.

    If ``hypes['jitter']['reseize_image']`` or
    ``hypes['jitter']['crop_patch']`` is set, every queue element has a
    static shape:
    * height/width are the patch size when cropping, otherwise the resized
      image size;
    * the image has ``hypes['arch']['num_channels']`` channels;
    * the label has ``hypes['arch']['num_classes']`` channels.
    Otherwise the element shapes are left unspecified.

    A scalar summary with the queue's fill fraction (size / capacity) is
    also registered.

    Parameters
    ----------
    hypes : dict
        Hyperparameters
    phase : str
        Phase name (e.g. ``'train'``); only used in the summary name.

    Returns
    -------
    tf.FIFOQueue
        Queue with element dtypes ``[tf.float32, tf.int32]`` and capacity 50.
    """
    dtypes = [tf.float32, tf.int32]
    jitter = hypes['jitter']

    # NOTE: 'reseize_image' (sic) must match the key spelling in the hypes file.
    shape_known = jitter['reseize_image'] or jitter['crop_patch']

    if shape_known:
        if jitter['crop_patch']:
            height = jitter['patch_height']
            width = jitter['patch_width']
        else:
            height = jitter['image_height']
            width = jitter['image_width']
        channel = hypes['arch']['num_channels']
        num_classes = hypes['arch']['num_classes']
        shapes = [[height, width, channel],
                  [height, width, num_classes]]
    else:
        shapes = None

    # Single source of truth for the capacity: used by the queue and the
    # summary below.
    capacity = 50
    q = tf.FIFOQueue(capacity=capacity, dtypes=dtypes, shapes=shapes)
    tf.summary.scalar("queue/%s/fraction_of_%d_full" %
                      (q.name + "_" + phase, capacity),
                      math_ops.cast(q.size(), tf.float32) * (1. / capacity))

    return q

#把 hypes 中 jitter 的 reseize_image(键名原文如此拼写)改为 true,这样队列元素的形状就是已知的

shapes =[[384,1248,3],[384,1248,2]]

#使用tf.FIFOQueue类创建一个先入先出队列.

#capacity:指定队列中的元素数量的上限。

#dtypes:DType对象的列表。dtypes的长度必须等于每个队列元素中张量的数量。

#shapes:(可选)每个队列元素中各张量的形状;为 None 时不固定形状。

#标量数据的汇总和记录使用 tf.summary.scalar

gou 

with tf.name_scope("Inputs"):
        image, labels = data_input.inputs(hypes, queue, phase='train')

 train 数据 输入图

构建


def _make_data_gen(hypes, phase, data_dir):
    """Return a data generator that outputs image samples.

    The sample list is ``hypes['data']['train_file']`` (phase ``'train'``)
    or ``hypes['data']['val_file']`` (phase ``'val'``). The path is resolved
    relative to *data_dir* and read with ``_load_gt_file``.

    Each ground-truth colour image is turned into a two-channel one-hot
    mask:
    * channel 0 is True where the pixel exactly equals
      ``hypes['data']['background_color']``;
    * channel 1 is True where it equals ``hypes['data']['road_color']``.
    A pixel matching neither colour is all-False.

    @ Returns
    image: integer array of shape [height, width, 3].
    Representing RGB value of each pixel.
    gt_image: boolean array of shape [height, width, 2] (background, road).
    Set `gt_image[i,j,k] == 1` if and only if pixel i,j
    is assigned class k. `gt_image[i,j,k] == 0` otherwise.

    [Alternatively make gt_image[i,j,*] a valid probability
    distribution.]

    In phase ``'val'`` each sample is yielded as is. In phase ``'train'``
    each sample is yielded twice, both times passed through
    ``jitter_input``: once as loaded and once flipped left-right.

    Raises
    ------
    AssertionError
        For an unknown *phase*. Since this is a generator, it is raised on
        the first ``next()``.
    """
    if phase == 'train':
        data_file = hypes['data']["train_file"]
    elif phase == 'val':
        data_file = hypes['data']["val_file"]
    else:
        # Explicit raise instead of ``assert False``: still fires under
        # ``python -O``; the exception type is unchanged for callers.
        raise AssertionError("Unknown Phase %s" % phase)

    data_file = os.path.join(data_dir, data_file)

    road_color = np.array(hypes['data']['road_color'])
    background_color = np.array(hypes['data']['background_color'])

    data = _load_gt_file(hypes, data_file)

    for image, gt_image in data:
        # Exact per-pixel colour match across the channel axis -> [H, W].
        gt_bg = np.all(gt_image == background_color, axis=2)
        gt_road = np.all(gt_image == road_color, axis=2)
        assert gt_road.shape == gt_bg.shape

        # [H, W, 2]: channel 0 = background, channel 1 = road.
        gt_image = np.stack((gt_bg, gt_road), axis=2)

        if phase == 'val':
            yield image, gt_image
        elif phase == 'train':
            yield jitter_input(hypes, image, gt_image)
            yield jitter_input(hypes, np.fliplr(image), np.fliplr(gt_image))

data_file = os.path.join(data_dir, data_file)

#data_file = DATA/data_road/train3.txt

data = _load_gt_file(hypes, data_file)

base_path = os.path.realpath(os.path.dirname(data_file))

def run_training(hypes, modules, tv_graph, tv_sess, start_step=0):

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 5
    评论
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值