本文对EAST算法中的重要模块进行梳理,包括多gpu训练、算法pipline、输入数据和输出数据的格式。
- 数据并行。使用多gpu进行训练,将batch的训练数据根据设定的gpu数量进行划分,每个gpu训练batch数据中的一部分,得到其结构风险,然后计算当前结构风险的梯度信息。当前batch数据的梯度全部计算完后,进行梯度更新。
for i, gpu_id in enumerate(gpus): #数据并行
with tf.device('/gpu:%d' % gpu_id):
with tf.name_scope('model_%d' % gpu_id) as scope:
iis = input_images_split[i]
isms = input_score_maps_split[i]
igms = input_geo_maps_split[i]
itms = input_training_masks_split[i]
total_loss, model_loss = tower_loss(iis, isms, igms, itms, reuse_variables)
batch_norm_updates_op = tf.group(*tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope))
reuse_variables = True
grads = opt.compute_gradients(total_loss)
tower_grads.append(grads) #汇总梯度
grads = average_gradients(tower_grads) #更新梯度
apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)
summary_op = tf.summary.merge_all()
# save moving average
variable_averages = tf.train.ExponentialMovingAverage(
FLAGS.moving_average_decay, global_step)
variables_averages_op &#