TensorFlow training loop

 

Training loop

https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/python/keras/engine/training_eager.py


def fit_loop(model,
             inputs,
             targets,
             sample_weights=None,
             class_weight=None,
             val_inputs=None,
             val_targets=None,
             val_sample_weights=None,
             batch_size=None,
             epochs=1,
             verbose=1,
             callbacks=None,
             shuffle=True,
             initial_epoch=0,
             steps_per_epoch=None,
             validation_steps=None):
  # Convert training inputs to an EagerIterator
  inputs, steps_per_epoch = training_utils.convert_to_iterator(
      x=inputs,
      y=targets,
      sample_weights=sample_weights,
      batch_size=batch_size,
      steps_per_epoch=steps_per_epoch,
      epochs=epochs,
      shuffle=shuffle)

  # (Callback configuration and the do_validation flag are set up in code
  #  elided from this excerpt.)
  # Iterate over the epochs.
  for epoch in range(initial_epoch, epochs):
    epoch_logs = {}
    iterator_fit_loop(
        model,
        inputs,
        class_weight,
        steps_per_epoch=steps_per_epoch,
        epoch_logs=epoch_logs,
        val_inputs=val_inputs,
        val_targets=val_targets,
        val_sample_weights=val_sample_weights,
        epochs=epochs,
        verbose=verbose,
        callbacks=callbacks,
        validation_steps=validation_steps,
        do_validation=do_validation,
        batch_size=batch_size)
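
The sketch below (not part of the source) shows how this fit_loop is reached from user code: with eager execution enabled in TF 1.x, an ordinary model.fit() call is dispatched to training_eager.fit_loop. The model architecture, data, and hyperparameters here are made up for illustration.

import numpy as np
import tensorflow as tf

tf.enable_eager_execution()  # TF 1.x: switch on eager mode before building the model

# An illustrative model; any compiled tf.keras model would do.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation='relu', input_shape=(4,)),
    tf.keras.layers.Dense(1)
])
# In TF 1.x eager mode a tf.train optimizer is the safe choice for compile().
model.compile(optimizer=tf.train.AdamOptimizer(0.01), loss='mse')

x = np.random.rand(256, 4).astype('float32')
y = np.random.rand(256, 1).astype('float32')

# Under eager execution this call ends up in the fit_loop shown above.
model.fit(x, y, batch_size=32, epochs=2, verbose=1)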



One epoch of iteration
def iterator_fit_loop(model, inputs, class_weight, steps_per_epoch, epoch_logs,
                      val_inputs=None, val_targets=None, val_sample_weights=None,
                      epochs=1, verbose=1, callbacks=None, validation_steps=None,
                      do_validation=False, batch_size=None):
  # Within one epoch, iterate over the training Dataset, unpack the features (x)
  # and labels (y) of each batch, and run a single optimization step.
  for step_index in range(steps_per_epoch):
    next_element = inputs.get_next()
    x, y, sample_weights = next_element
    outs, loss, loss_metrics, masks = _process_single_batch(
        model, x, y, sample_weights=sample_weights, training=True)
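
For intuition about what inputs.get_next() yields inside this loop, here is a minimal standalone sketch (my own, assuming TF 1.x with eager execution; not the convert_to_iterator helper itself) of the tf.data pattern it builds on: arrays are wrapped in a shuffled, batched Dataset, and each get_next() call returns one batch of features and labels as eager tensors. The shapes are arbitrary.

import numpy as np
import tensorflow as tf

tf.enable_eager_execution()

features = np.random.rand(100, 4).astype('float32')
labels = np.random.rand(100, 1).astype('float32')

# Shuffle and batch the data, roughly what convert_to_iterator arranges internally.
dataset = (tf.data.Dataset.from_tensor_slices((features, labels))
           .shuffle(buffer_size=100)
           .batch(32))

iterator = dataset.make_one_shot_iterator()
x_batch, y_batch = iterator.get_next()   # one batch: x_batch (32, 4), y_batch (32, 1)
print(x_batch.shape, y_batch.shape)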



def _process_single_batch(model, inputs, targets, sample_weights=None, training=False):
    # Record the forward pass on a GradientTape so gradients can be computed from the loss.
    with GradientTape() as tape:
        outs, loss, loss_metrics, masks = _model_loss(
            model, inputs, targets, sample_weights=sample_weights, training=training)
    grads = tape.gradient(loss, model._collected_trainable_weights)
    model.optimizer.apply_gradients(zip(grads, model._collected_trainable_weights))
    return outs, loss, loss_metrics, masks


def _model_loss(model, inputs, targets, sample_weights=None, training=False):
    total_loss = 0.
    kwargs = {'training': training} if model._expects_training_arg else {}
    outs = model.call(inputs, **kwargs)   # forward pass
    if not isinstance(outs, list):
        outs = [outs]
    masks = [None for _ in outs]          # (per-output mask computation elided here)
    loss_metrics = []
    # Accumulate the weighted, masked loss of every model output.
    for i, loss_fn in enumerate(model.loss_functions):
        weights = sample_weights[i] if sample_weights else None
        mask = masks[i]
        weighted_masked_fn = training_utils.weighted_masked_objective(loss_fn)
        output_loss = weighted_masked_fn(targets[i], outs[i], weights, mask=mask)
        loss_metrics.append(backend.mean(output_loss))
        loss_weight = model.loss_weights_list[i]
        total_loss += loss_weight * output_loss
    return outs, total_loss, loss_metrics, masks
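
To make the aggregation in _model_loss concrete, here is a hedged standalone sketch of the same idea: each output's loss is reduced to a mean, scaled by its loss weight, and summed into the total loss. The two outputs, loss functions, and loss weights below are invented for illustration and are not taken from the source.

import tensorflow as tf

tf.enable_eager_execution()

# Pretend the model has two outputs, each with its own loss function and weight.
y_true = [tf.constant([[1.0], [0.0]]), tf.constant([[2.0], [3.0]])]
y_pred = [tf.constant([[0.8], [0.2]]), tf.constant([[1.5], [2.5]])]
loss_fns = [tf.keras.losses.mean_squared_error,
            tf.keras.losses.mean_absolute_error]
loss_weights = [1.0, 0.5]

total_loss = 0.0
loss_metrics = []
for i, loss_fn in enumerate(loss_fns):
    output_loss = tf.reduce_mean(loss_fn(y_true[i], y_pred[i]))  # per-output mean loss
    loss_metrics.append(output_loss)
    total_loss += loss_weights[i] * output_loss                  # weighted sum

print(total_loss.numpy(), [m.numpy() for m in loss_metrics])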


def _eager_metrics_fn(model, outputs, targets, sample_weights=None, masks=None):
  # Evaluate the model's compiled metrics on this batch and return the results.
  metric_results = model._handle_metrics(
      outputs, targets=targets, sample_weights=sample_weights, masks=masks)
  return metric_results

 


1. Iterate over the epochs. One full pass through the dataset is one epoch.


2. Within an epoch, iterate over the examples in the training Dataset, grabbing their features (x) and labels (y).
  for step_index in range(steps_per_epoch):
    next_element = inputs.get_next()
    x, y, sample_weights = next_element


3. Use the example's features to make a prediction and compare it with the label. Measure how inaccurate the prediction is, and use that value to compute the model's loss and gradients.

    outs, loss, loss_metrics, masks = _model_loss(
        model, inputs, targets, sample_weights=sample_weights, training=training)
    grads = tape.gradient(loss, model._collected_trainable_weights)


4. Use the optimizer to update the model's variables.
    model.optimizer.apply_gradients(zip(grads,model._collected_trainable_weights))


5. Track some statistics for visualization.
    metrics_results = _eager_metrics_fn(
        model, outs, y, sample_weights=sample_weights, masks=masks)


6. Repeat these steps for every epoch. A condensed end-to-end sketch follows below.
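
Putting steps 1-6 together, below is a condensed, self-contained sketch of the same loop written directly with tf.GradientTape instead of the private helpers above (assuming TF 1.x with eager execution enabled; the model, data, and hyperparameters are placeholders).

import numpy as np
import tensorflow as tf

tf.enable_eager_execution()

model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation='relu', input_shape=(4,)),
    tf.keras.layers.Dense(1)
])
optimizer = tf.train.AdamOptimizer(0.01)
loss_fn = tf.keras.losses.mean_squared_error

features = np.random.rand(256, 4).astype('float32')
labels = np.random.rand(256, 1).astype('float32')
dataset = (tf.data.Dataset.from_tensor_slices((features, labels))
           .shuffle(256).batch(32))

for epoch in range(2):                                   # 1. iterate over the epochs
  for x, y in dataset:                                   # 2. fetch features (x) and labels (y)
    with tf.GradientTape() as tape:
      preds = model(x)                                   # 3. predict ...
      loss = tf.reduce_mean(loss_fn(y, preds))           #    ... and measure the loss
    grads = tape.gradient(loss, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))    # 4. update variables
  print('epoch', epoch, 'last batch loss:', loss.numpy())             # 5. track statistics
# 6. the outer loop repeats these steps for every epoch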
