1、inputs/kitti_seg_input.py 确保本地训练文件夹中已经下载了正确的数据集;如果没有,则下载并解压这些数据,然后返回一个由 DataSet 实例组成的字典。
2、train.py
tf.app.flags.FLAGS.hypes :hypes/KittiSeg.json 相当于hypes地址
print("tf.app.flags.FLAGS.hypes",tf.app.flags.FLAGS.hypes )
with open(tf.app.flags.FLAGS.hypes, 'r') as f: with open 读写函数,‘r’ 表示读
utils.load_plugins(): hypes/KittiSeg.json 为basedir
import ast
#将string 转化为dict 并且验证执行合法的类型
mod_dict = ast.literal_eval(tf.app.flags.FLAGS.mod)
dict_merge(hypes, mod_dict)
两个dict 合并:
dict1={1:[1,11,111],2:[2,22,222]}
dict2={3:[3,33,333],4:[4,44,444]}
方法1、
dictMerged1=dict(dict1.items()+dict2.items())  # 仅适用于 Python 2;Python 3 中需写成 dict(list(dict1.items()) + list(dict2.items()))
方法2、
dictMerged2=dict(dict1, **dict2)
合并
{1:[1,11,111],2:[2,22,222],3:[3,33,333],4:[4,44,444]}
os.environ['TV_DIR_RUNS'] = os.path.join(os.environ['TV_DIR_RUNS'], 'KittiSeg')
os.path.join(path,name):连接目录与文件名或目录
>>> os.path.join('c:\\Python','a.txt')
'c:\\Python\\a.txt'
>>> os.path.join('c:\\Python','f1')
'c:\\Python\\f1'
>>>
utils.set_dirs(hypes, tf.app.flags.FLAGS.hypes)
->
# Set base_path
if 'base_path' not in hypes['dirs']:
base_path = os.path.dirname(os.path.realpath(hypes_fname))
hypes['dirs']['base_path'] = base_path
else:
base_path = hypes['dirs']['base_path']
os.path.dirname(path):返回文件路径
>>> os.path.dirname('c:\\Python\\a.txt')
'c:\\Python'
os.path.realpath():返回hypes_fname 真实路径
print os.path.realpath(hypes_fname)
#base_path hypes/
# Set output dir
if 'output_dir' not in hypes['dirs']:
if 'TV_DIR_RUNS' in os.environ:
runs_dir = os.path.join(base_path, os.environ['TV_DIR_RUNS'])
else:
runs_dir = os.path.join(base_path, '../RUNS')
#runs_dir RUNS/
# test for project dir
if hasattr(FLAGS, 'project') and FLAGS.project is not None:
runs_dir = os.path.join(runs_dir, FLAGS.project)
hasattr(FLAGS,'project') true,
hasattr(object, name)
判断object对象中是否存在name属性,当然对于python的对象而言,属性包含变量和方法;有则返回True,没有则返回False;需要注意的是name参数是string类型,所以不管是要判断变量还是方法,其名称都以字符串形式传参;getattr和setattr也同样;
>>>
>>> class A():
...     name = 'python'
...     def func(self):
...         return 'A()类的方法func()'
>>>
>>> hasattr(A, 'name')
True
>>>
>>> hasattr(A, 'age')
False
>>>
>>> hasattr(A, 'func')
True
>>>
if 'output_dir' not in hypes['dirs']:
if 'TV_DIR_RUNS' in os.environ:
runs_dir = os.path.join(base_path, os.environ['TV_DIR_RUNS'])
else:
runs_dir = os.path.join(base_path, '../RUNS')
# test for project dir
if hasattr(FLAGS, 'project') and FLAGS.project is not None:
runs_dir = os.path.join(runs_dir, FLAGS.project)
if not FLAGS.save and FLAGS.name is None:
output_dir = os.path.join(runs_dir, 'debug')
else:
json_name = hypes_fname.split('/')[-1].replace('.json', '')
date = datetime.now().strftime('%Y_%m_%d_%H.%M')
if FLAGS.name is not None:
json_name = FLAGS.name + "_" + json_name
run_name = '%s_%s' % (json_name, date)
output_dir = os.path.join(runs_dir, run_name)
hypes['dirs']['output_dir'] = output_dir
run_name : KittiSeg.json去掉.json +'_'+date
output_dir = RUNS/KittiSeg_date/
if 'data_dir' not in hypes['dirs']:
if 'TV_DIR_DATA' in os.environ:
data_dir = os.path.join(base_path, os.environ['TV_DIR_DATA'])
else:
data_dir = os.path.join(base_path, '../DATA')
hypes['dirs']['data_dir'] = data_dir
data_dir=DATA/
def _add_paths_to_sys(hypes):
"""
Add all module dirs to syspath.
This adds the dirname of all modules to path.
Parameters
----------
hypes : dict
Hyperparameters
"""
base_path = hypes['dirs']['base_path']
if 'path' in hypes:
for path in hypes['path']:
path = os.path.realpath(os.path.join(base_path, path))
sys.path.insert(1, path)
return
sys.path.insert(1, path):把 path 插入到模块搜索路径 sys.path 的第 1 位,使其被优先搜索(这里加入的是 incl/ 等目录)。
utils.set_dirs(hypes, tf.app.flags.FLAGS.hypes)
#base_path hypes/
#runs_dir RUNS/
output_dir = RUNS/KittiSeg_date/
data_dir=DATA/
def maybe_download_and_extract(hypes):
    """
    Download the data if it isn't downloaded by now.

    The input module named by ``hypes['model']['input_file']`` (resolved
    against ``hypes['dirs']['base_path']``) is imported under the module
    name ``"input"``.  If that module defines
    ``maybe_download_and_extract``, the hook is called with ``hypes``;
    otherwise nothing else happens.

    Parameters
    ----------
    hypes : dict
        Hyperparameters
    """
    def _load_source(name, path):
        """Import the source file at ``path`` as a module called ``name``."""
        try:
            import imp
        except ImportError:
            # ``imp`` was removed in Python 3.12; build the module with
            # importlib instead.
            import importlib.machinery
            import importlib.util
            import sys

            loader = importlib.machinery.SourceFileLoader(name, path)
            spec = importlib.util.spec_from_file_location(
                name, path, loader=loader)
            module = importlib.util.module_from_spec(spec)
            sys.modules[name] = module
            loader.exec_module(module)
            return module
        return imp.load_source(name, path)

    input_path = os.path.join(hypes['dirs']['base_path'],
                              hypes['model']['input_file'])
    data_input = _load_source("input", input_path)
    if hasattr(data_input, 'maybe_download_and_extract'):
        data_input.maybe_download_and_extract(hypes)
f=inputs/kitti_seg_input.py
imp.load_source(name,pathname[,file])的作用把源文件pathname导入到name模块中,name可以是自定义的名字或者内置的模块名称。
把kitti_seg_input.py 放到input 模块中 ,data_input可直接调用其中函数
假设在路径E:/Code/Python3/下有一个文件test.py, 内容如下:
def myadd(x, y):
    """Return the result of ``x + y``."""
    total = x + y
    return total
import imp
m = imp.load_source('mymod', 'E:/Code/Python3/test.py')
# 方法一
a = m.myadd(4, 10)
print(a)
# 方法二
import mymod
a = mymod.myadd(4, 10)
print(a)
如果data_input(kitti_seg_input.py)中有maybe_download_and_extract()函数
执行函数maybe_download_and_extract(hypes)
def maybe_download_and_extract(hypes):
    """ Downloads, extracts and prepares data.

    Creates ``hypes['dirs']['data_dir']`` if it is missing.  When both
    ``vgg16.npy`` and ``data_road/`` already exist there, returns
    immediately.  Otherwise the KITTI road archive is downloaded from
    ``hypes['data']['kitti_url']`` and extracted, ``data/train3.txt`` and
    ``data/val3.txt`` (relative to the working directory) are copied into
    ``data_road/``, and the VGG weights are downloaded from
    ``hypes['data']['vgg_url']``.

    Parameters
    ----------
    hypes : dict
        Hyperparameters

    Raises
    ------
    SystemExit
        With exit code 1 when the KITTI URL is empty or does not end in
        ``kitti/data_road.zip``.
    """
    data_dir = hypes['dirs']['data_dir']
    if not os.path.exists(data_dir):
        os.makedirs(data_dir)

    data_road_zip = os.path.join(data_dir, 'data_road.zip')
    vgg_weights = os.path.join(data_dir, 'vgg16.npy')
    kitti_road_dir = os.path.join(data_dir, 'data_road/')

    # Everything needed is already on disk: nothing to do.
    if os.path.exists(vgg_weights) and os.path.exists(kitti_road_dir):
        return

    import sys
    import zipfile
    from shutil import copy2

    # Validate the KITTI URL before touching the network.
    kitti_data_url = hypes['data']['kitti_url']
    if kitti_data_url == '':
        problem = "Data URL for Kitti Data not provided."
    elif not kitti_data_url.endswith('kitti/data_road.zip'):
        problem = "Wrong url."
    else:
        problem = None

    if problem is not None:
        logging.error(problem)
        url = "http://www.cvlibs.net/download.php?file=data_road.zip"
        logging.error("Please visit: {}".format(url))
        logging.error("and request Kitti Download link.")
        logging.error("Enter URL in hypes/kittiSeg.json")
        sys.exit(1)

    # Imported only once a download is actually required.
    import tensorvision.utils as utils

    # Download KITTI DATA
    logging.info("Downloading Kitti Road Data.")
    utils.download(kitti_data_url, data_dir)

    # Extract and prepare KITTI DATA
    logging.info("Extracting kitti_road data.")
    with zipfile.ZipFile(data_road_zip, 'r') as archive:
        archive.extractall(data_dir)

    logging.info("Preparing kitti_road data.")
    train_txt = "data/train3.txt"
    val_txt = "data/val3.txt"
    copy2(train_txt, kitti_road_dir)
    copy2(val_txt, kitti_road_dir)

    # Download VGG DATA
    vgg_url = hypes['data']['vgg_url']
    logging.info("Downloading VGG weights.")
    utils.download(vgg_url, data_dir)
提前把data_road/ vgg16.npy,data_road.zip放到DATA中,直接返回,不执行操作。
data_dir =DATA/
data_road_zip=DATA/data_road.zip
vgg_weights= DATA/vgg16.npy
kitti_road_dir = DATA/data_road/
def initialize_training_folder(hypes, files_dir="model_files", logging=True):
    """
    Create the run's output folders and snapshot the configuration there.

    Under ``hypes['dirs']['output_dir']`` this creates the ``files_dir``
    and ``images`` folders (if missing), records the image folder in
    ``hypes['dirs']['image_dir']``, optionally attaches a file log handler
    writing to ``output.log``, dumps ``hypes`` to ``<files_dir>/hypes.json``
    and copies the five model source files into ``files_dir``.

    Parameters
    ----------
    hypes : dict
        Hyperparameters; ``hypes['dirs']['image_dir']`` is set as a side
        effect.
    files_dir : str
        Name of the sub-folder receiving hypes.json and the model files.
    logging : bool
        When true, console output is additionally written to output.log.
    """
    output_dir = hypes['dirs']['output_dir']

    target_dir = os.path.join(output_dir, files_dir)
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)

    image_dir = os.path.join(output_dir, "images")
    if not os.path.exists(image_dir):
        os.makedirs(image_dir)
    hypes['dirs']['image_dir'] = image_dir

    # Mirror the console output into the training folder.
    if logging:
        utils.create_filewrite_handler(os.path.join(output_dir, "output.log"))

    # TODO: read more about loggers and make file logging neater.

    # e.g. RUNS/KittiSeg_<date>/model_files/hypes.json
    with open(os.path.join(target_dir, 'hypes.json'), 'w') as outfile:
        json.dump(hypes, outfile, indent=2, sort_keys=True)

    # (key in hypes['model'], file name of the copy inside target_dir)
    snapshots = (
        ('input_file', "data_input.py"),
        ('architecture_file', "architecture.py"),
        ('objective_file', "objective.py"),
        ('optimizer_file', "solver.py"),
        ('evaluator_file', "eval.py"),
    )
    for model_key, copy_name in snapshots:
        _copy_parameters_to_traindir(
            hypes, hypes['model'][model_key], copy_name, target_dir)
#output_dir = RUNS/KittiSeg_date/
#target_dir = RUNS/KittiSeg_date/model_files/
#image_dir = RUNS/KittiSeg_date/images/
自己写数据集的时候标签一定不能多一个空格行,谨记
#logging_dir = RUNS/KittiSeg_date/output.log
#target_file = RUNS/KittiSeg_date/model_files/hypes.json
train.initialize_training_folder(hypes)
train.initialize_training_folder(hypes)
#初始化输出文件
def do_training(hypes):
    """
    Build the training and validation graphs and run the training loop.

    The modules named in ``hypes`` are loaded, the training graph is built
    on top of an input queue, an inference graph sharing the same variables
    is added under the ``Validation`` name scope, the enqueuing threads are
    started and ``run_training`` is invoked.  Afterwards the coordinator in
    ``tv_sess`` is asked to stop and the input threads are joined.

    Parameters
    ----------
    hypes : dict
        Hyperparameters
    """
    modules = utils.load_modules_from_hypes(hypes)

    with tf.Session() as session:
        # Build the graph based on the loaded modules.
        with tf.name_scope("Queues"):
            train_queue = modules['input'].create_queues(hypes, 'train')

        tv_graph = core.build_training_graph(hypes, train_queue, modules)

        # Prepare the tv session.
        tv_sess = core.start_tv_session(hypes)

        with tf.name_scope('Validation'):
            # Reuse the variables created by the training graph.
            tf.get_variable_scope().reuse_variables()
            image_placeholder = tf.placeholder(tf.float32)
            # Add a leading batch dimension of size 1.
            batched_image = tf.expand_dims(image_placeholder, 0)
            batched_image.set_shape([1, None, None, 3])
            inference_out = core.build_inference_graph(hypes, modules,
                                                       image=batched_image)
            tv_graph['image_pl'] = image_placeholder
            tv_graph['inf_out'] = inference_out

        # Start the data load.
        modules['input'].start_enqueuing_threads(
            hypes, train_queue, 'train', session)

        # Everything is built: start the training loop.
        run_training(hypes, modules, tv_graph, tv_sess)

        # Stop the input threads.
        coordinator = tv_sess['coord']
        coordinator.request_stop()
        coordinator.join(tv_sess['threads'])
modules = utils.load_modules_from_hypes(hypes)
#导入模型训练程序
def load_modules_from_hypes(hypes, postfix=""):
    """Load all modules from the files specified in hypes.

    Namely the modules loaded are:
    input_file, architecture_file, objective_file, optimizer_file,
    evaluator_file

    Every file name is taken from ``hypes['model']`` and resolved against
    ``hypes['dirs']['base_path']``.

    Parameters
    ----------
    hypes : dict
        Hyperparameters
    postfix : str
        Suffix appended to each module name (``"input_<postfix>"``, ...).

    Returns
    -------
    modules : dict
        The loaded modules under the keys 'input', 'arch', 'objective',
        'solver' and 'eval'.
    """
    def _load_source(name, path):
        """Import the source file at ``path`` as a module called ``name``."""
        try:
            import imp
        except ImportError:
            # ``imp`` was removed in Python 3.12; build the module with
            # importlib instead.
            import importlib.machinery
            import importlib.util
            import sys

            loader = importlib.machinery.SourceFileLoader(name, path)
            spec = importlib.util.spec_from_file_location(
                name, path, loader=loader)
            module = importlib.util.module_from_spec(spec)
            sys.modules[name] = module
            loader.exec_module(module)
            return module
        return imp.load_source(name, path)

    base_path = hypes['dirs']['base_path']

    # (key in the returned dict, key in hypes['model'], module-name prefix)
    # The trailing comments name the file each entry points to in KittiSeg.
    module_specs = (
        ('input', 'input_file', 'input'),              # kitti_seg_input.py
        ('arch', 'architecture_file', 'arch'),         # fcn8_vgg.py
        ('objective', 'objective_file', 'objective'),  # kitti_multiloss.py
        ('solver', 'optimizer_file', 'solver'),        # generic_optimizer.py
        ('eval', 'evaluator_file', 'evaluator'),       # kitti_eval.py
    )

    modules = {}
    for key, hypes_key, prefix in module_specs:
        path = os.path.join(base_path, hypes['model'][hypes_key])
        modules[key] = _load_source("%s_%s" % (prefix, postfix), path)
    return modules
modules['input'] kitti_seg_input.py
modules['arch'] fcn8_vgg.py
modules['objective'] kitti_multiloss.py
modules['solver'] generic_optimizer.py
modules['eval'] kitti_eval.py
with tf.Session() as sess:
#运行默认图
with tf.name_scope("Queues"):
queue = modules['input'].create_queues(hypes, 'train')
在inputs/kitti_seg_input.py下面找 create_queues()
def create_queues(hypes, phase):
    """Create Queues.

    Builds a ``tf.FIFOQueue`` whose elements are a float32 tensor and an
    int32 tensor, and registers a scalar summary of how full the queue is.

    When ``hypes['jitter']`` enables resizing or patch cropping, the
    element shapes are fixed to ``[height, width, num_channels]`` and
    ``[height, width, num_classes]``; the patch size takes precedence over
    the image size.  Otherwise the shapes are left unspecified.

    Parameters
    ----------
    hypes : dict
        Hyperparameters
    phase : str
        Only used in the name of the summary.

    Returns
    -------
    q : tf.FIFOQueue
    """
    arch = hypes['arch']
    jitter = hypes['jitter']
    dtypes = [tf.float32, tf.int32]

    # 'reseize_image' is the (misspelled) key used by the hypes file.
    if jitter['reseize_image'] or jitter['crop_patch']:
        if jitter['crop_patch']:
            height = jitter['patch_height']
            width = jitter['patch_width']
        else:
            height = jitter['image_height']
            width = jitter['image_width']
        shapes = [[height, width, arch['num_channels']],
                  [height, width, arch['num_classes']]]
    else:
        shapes = None

    # A single constant drives both the queue size and the summary, so the
    # reported fill fraction always matches the real capacity.
    capacity = 50
    q = tf.FIFOQueue(capacity=capacity, dtypes=dtypes, shapes=shapes)
    tf.summary.scalar("queue/%s/fraction_of_%d_full" %
                      (q.name + "_" + phase, capacity),
                      math_ops.cast(q.size(), tf.float32) * (1. / capacity))

    return q
#改reseize_image 为true
shapes =[[384,1248,3],[384,1248,2]]
#使用tf.FIFOQueue类创建一个先入先出队列.
#capacity:指定队列中的元素数量的上限。
#dtypes:DType对象的列表。dtypes的长度必须等于每个队列元素中张量的数量。
#shapes:(可选项)
#标量数据的汇总和记录使用 tf.summary.scalar
gou
with tf.name_scope("Inputs"):
image, labels = data_input.inputs(hypes, queue, phase='train')
train 数据 输入图
构建
def _make_data_gen(hypes, phase, data_dir):
    """Return a data generator that outputs image samples.

    The sample list is ``hypes['data']['train_file']`` or
    ``hypes['data']['val_file']`` (depending on ``phase``) inside
    ``data_dir``.  In the 'val' phase each sample is yielded unchanged; in
    the 'train' phase each sample is yielded twice through
    ``jitter_input``: once as loaded and once flipped left-right.

    @ Returns
    image: integer array of shape [height, width, 3].
    Representing RGB value of each pixel.
    gt_image: boolean array of shape [height, width, num_classes].
    Set `gt_image[i,j,k] == 1` if and only if pixel i,j
    is assigned class k. `gt_image[i,j,k] == 0` otherwise.
    [Alternatively make gt_image[i,j,*] a valid probability
    distribution.]
    """
    if phase == 'train':
        list_name = hypes['data']["train_file"]
    elif phase == 'val':
        list_name = hypes['data']["val_file"]
    else:
        assert False, "Unknown Phase %s" % phase

    list_path = os.path.join(data_dir, list_name)

    road_color = np.array(hypes['data']['road_color'])
    background_color = np.array(hypes['data']['background_color'])

    for image, gt_image in _load_gt_file(hypes, list_path):
        # A pixel belongs to a class when all its channels match that
        # class's colour.
        bg_mask = np.all(gt_image == background_color, axis=2)
        road_mask = np.all(gt_image == road_color, axis=2)
        assert(road_mask.shape == bg_mask.shape)

        # Channel 0 is background, channel 1 is road.
        gt_image = np.stack((bg_mask, road_mask), axis=2)

        if phase == 'train':
            yield jitter_input(hypes, image, gt_image)
            yield jitter_input(hypes, np.fliplr(image), np.fliplr(gt_image))
        elif phase == 'val':
            yield image, gt_image
data_file = os.path.join(data_dir, data_file)
#data_file = DATA/data_road/train3.txt
data = _load_gt_file(hypes, data_file)
base_path = os.path.realpath(os.path.dirname(data_file))
def run_training(hypes, modules, tv_graph, tv_sess, start_step=0):