train()
def train_one_epoch(sess, ops, train_writer, stack):
log_string(str(datetime.now()))
update_progress(0)
loss_sum = 0
confusion_matrix = metric.ConfusionMatrix(NUM_CLASSES)
在训练代码中增加了datetime时间显示和加载进度条的用法,更加人性化
1.datetime处理日期和时间
datetime.now() #获取当前时间
from datetime import datetime
now=datetime.now()
print('当前时间:',now)
utcnow=datetime.utcnow()
print('世界标准时间:',utcnow)
结果如下:
这常常用于在训练预测过程中提供具体的时间数据。
2.update_progress()
def update_progress(progress):
    """Render a one-line text progress bar on stdout.

    Parameters
    ----------
    progress : float or int
        Completion fraction; expected in [0, 1]. Ints are converted to
        float, any non-numeric value falls back to 0, and out-of-range
        values are clamped so the bar never under/overflows.
    """
    bar_length = 10  # Modify this to change the length of the progress bar
    # Normalise the input: the original rounded the int path to 2 decimals,
    # which is a no-op on a whole number — a plain float() conversion suffices.
    if isinstance(progress, int):
        progress = float(progress)
    if not isinstance(progress, float):
        progress = 0
    if progress < 0:
        progress = 0
    if progress >= 1:
        progress = 1
    filled = int(round(bar_length * progress))
    text = "\rProgress: [{}] {}%".format(
        "#" * filled + "-" * (bar_length - filled), progress * 100)
    # \r returns to the start of the line so successive calls overwrite
    # the same line instead of scrolling.
    sys.stdout.write(text)
    sys.stdout.flush()
进度条显示函数定义
update_progress(0)
progress = float(batch_idx) / float(num_batches)
update_progress(round(progress, 2))
update_progress(1)
以上代码以后通用,可借鉴。
train_one_epoch()
# Train over num_batches batches
for batch_idx in range(num_batches):
# Refill more batches if empty
progress = float(batch_idx) / float(num_batches)
update_progress(round(progress, 2))
batch_data, batch_label, batch_weights = stack.get()
重点弄清楚,网络训练数据batch_data,batch_label,batch_weights是如何生成的。
def train():
with tf.Graph().as_default():
stacker, stack_validation, stack_train = init_stacking()
此时已经在/dataset/semantic_downsampled/下保存了采样好(经过排除0标签点和体素下采样)的pcd文件以及对应的label文件。经过init_stacking()初始化堆叠
def init_stacking():
    """Create the bounded train/validation batch queues and start the
    background process that keeps them filled.

    Returns:
        (stacker, stack_validation, stack_train): the filler
        multiprocessing.Process and the two bounded queues it feeds.
    """
    with tf.device("/cpu:0"):
        # Queues that contain several batches in advance
        num_train_batches = TRAIN_DATASET.get_num_batches(PARAMS["batch_size"])  # 152
        """ Total points across all training files / (batch_size * points
        per sample) gives the number of batches in one epoch. """
        num_validation_batches = VALIDATION_DATASET.get_num_batches(
            PARAMS["batch_size"]
        )
        # Bounded queues: once num_*_batches items are enqueued, put()
        # blocks until a consumer takes an item.
        stack_train = mp.Queue(num_train_batches)
        stack_validation = mp.Queue(num_validation_batches)
        # Producer process; fill_queues loops forever topping up both queues.
        stacker = mp.Process(
            target=fill_queues,
            args=(stack_train, stack_validation, num_train_batches, num_validation_batches,))
        stacker.start()
        return stacker, stack_validation, stack_train
mp.Queue(152)创建了一个队列,可放入152个项目上限,一旦达到此大小,插入将被阻塞,直到消耗队列项目为止。执行Process函数生成多进程。
mp.Process(target=fill_queues,
args=(stack_train, #152个项目的队列
stack_validation, #148个项目的队列
num_train_batches, #152
num_validation_batches,)) #148
def fill_queues(
        stack_train, stack_validation, num_train_batches, num_validation_batches):
    """Producer loop (run in its own process): keep both batch queues full.

    Repeatedly launches async get_batch() jobs on a worker pool so that
    queue size + in-flight jobs never exceeds the per-epoch batch count,
    then moves finished batches into the corresponding queue.

    Args:
        stack_train: bounded mp.Queue receiving training batches.
        stack_validation: bounded mp.Queue receiving validation batches.
        num_train_batches: capacity target for stack_train.
        num_validation_batches: capacity target for stack_validation.

    Note: this function never returns; the owning Process is expected to
    be terminated by the parent.
    """
    # Pool of workers used to compute batches concurrently.
    pool = mp.Pool(processes=mp.cpu_count())
    launched_train = 0       # async train jobs launched but not yet enqueued
    launched_validation = 0  # async validation jobs launched but not yet enqueued
    results_train = []       # Temp buffer before filling the stack_train
    results_validation = []  # Temp buffer before filling the stack_validation
    # Launch as much as n
    while True:
        # Train queue has priority; only top up validation when train is full.
        if stack_train.qsize() + launched_train < num_train_batches:
            results_train.append(pool.apply_async(get_batch, args=("train",)))
            launched_train += 1
        elif stack_validation.qsize() + launched_validation < num_validation_batches:
            results_validation.append(pool.apply_async(get_batch, args=("validation",)))
            launched_validation += 1
        # Drain finished jobs. Iterate over a snapshot: the original code
        # called results_*.remove(p) while iterating the same list, which
        # skips the element following each removal.
        for p in list(results_train):
            if p.ready():
                stack_train.put(p.get())
                results_train.remove(p)
                launched_train -= 1
        for p in list(results_validation):
            if p.ready():
                stack_validation.put(p.get())
                results_validation.remove(p)
                launched_validation -= 1
        # Stability: avoid busy-spinning the CPU.
        time.sleep(0.01)
这里调用Queue(152),相当于 for i in range(num_batch): 完成一个epoch需要训练152个批次尺寸的数据。每次提取出16批次的8192点云数据batch_point,batch_label,batch_weight打包送进stack_train,每送一组数据,Queue.qsize()加一,直到送进152个批次数据。所以init_stacking()导出的stack_train()是包含了所有数据集的152个16批次的8192个点的数据。总共包含19922944个点的训练集。
调用get_batch()函数,返回TRAIN_DATASET.sample_batch_in_all_files()
sample_in_all_files()
从所有训练集中乱序提取一个数据集。
利用sample()
,读取该数据集的点坐标数据.points,从点云中随机选取一点作为中心点center_point。
算出场景z的范围 = 最大Z坐标 - 最小Z坐标。
确定包围盒范围,box_min盒最小坐标,box_max盒最大坐标。
mask = (np.sum((self.points[i_min:i_max, :] >= box_min)
* (self.points[i_min:i_max, :] <= box_max),axis=1,)== 3 )
因为所有的point都按x排序过,所以找出box_max[0]以及box_min[0]在点云中的索引,提取出所有满足小于最大坐标,大于最小坐标的所有点索引mask。
mask = np.hstack((
np.zeros(i_min, dtype=bool),mask,
np.zeros(len(self.points) - i_max, dtype=bool),))
构建在包围盒内的全局索引
points = points[scene_extract_mask]
labels = self.labels[scene_extract_mask]
colors = self.colors[scene_extract_mask]
sample_mask = self._get_fix_sized_sample_mask(points, num_points_per_sample)
points = points[sample_mask]
labels = labels[sample_mask]
colors = colors[sample_mask]
根据索引提取包围盒点云,再_get_fix_sized_sample_mask()
随机获取8192个索引,从包围盒点云中提取8192个点云。对应的labels和colors。返回的是随机数据集随机包围盒中的8192个点的points,labels,colors。
随机获取16个随机数据集的随机8192个点云信息作为batch_data,batch_label,batch_weight。
for _ in range(batch_size):
points, labels, colors, weights = self.sample_in_all_files(is_training=True)
if self.use_color:
batch_data.append(np.hstack((points, colors)))
else:
batch_data.append(points)
batch_label.append(labels)
batch_weights.append(weights)
batch_data = np.array(batch_data)
batch_label = np.array(batch_label)
batch_weights = np.array(batch_weights)
将所有16个batch的点云数据全部append进batch_data,batch_label,batch_weights。