Tensorflow官方文档:Writing a training loop from scratch

background:Tensorflow的模型API一般流程是

        build(通过functional API或者sequential模型) - compile(specify loss, optimizer, metrics) - fit

这篇文章主要介绍是compile和fit的底层api代码。直接上核心部分代码

optimizer = keras.optimizers.SGD(learning_rate=1e-3)
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

epochs = 2
for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))

    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        
        with tf.GradientTape() as tape:

            logits = model(x_batch_train, training=True)  # Logits for this minibatch
            loss_value = loss_fn(y_batch_train, logits)

        grads = tape.gradient(loss_value, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))

        # Log every 200 batches.
        if step % 200 == 0:
            print(
                "Training loss (for one batch) at step %d: %.4f"
                % (step, float(loss_value))
            )
            print("Seen so far: %s samples" % ((step + 1) * batch_size))

 思路是对epochs进行for循环:

        在每一个epoch中,遍历数据batches:

                对于每一个batch数据,在with gradientTape下计算当前batch的损失函数,然后在gradientTape外计算损失函数对于model.trainable_weights的导数,然后用optimizer来update导数:

optimizer.apply_gradients(zip(grads, model.trainable_weights))

这里其实也还是用了tf封装后的函数,如果不用optimizer和trainable_weights的话,就要具体创建模型的参数w, b等,计算导数后用w.assign_sub来update参数

示例代码中,每200个epoch输出当前epoch的损失,那如果我想计算整个training epoch的损失怎么计算呢?

        1. 用list存储每个batch的loss值,在epoch结束时计算均值

        2. 或者,在epoch外创建一个tf.metrics object, 例如tf.metrics.Mean(),在每个loop下利用reset_state(), update_state()和result()方法 (我后续发现这个在前一个tutorial里面customize waht happens in Model.fit也有介绍)

Metrics使用

metrics使用核心思想就是建立tf.metrics对象,然后在每个batch的forward pass (也就是用模型计算一遍input)后update_state(),在epoch结束时候打印结果result(),同时reset_state()

对于training metrics,metrics的计算可以在forward-backward的过程中计算,对于validation metrics,也可以用batch同样方法计算,但是我觉得也可以直接不batch对整个数据集一次性计算

Graph Execution

文章介绍了利用@tf.function的magic将每个batch/step中的代码变成graph execution,相对于default的eager execution可以提高优化性能 (具体为什么没有详细说,我猜是对batch中每一个数据单元进行并行运算的时候,不用等到所有的当前步骤执行完再开始下一步骤,例如我再算b的loss_fn的时候可以同时进行a的gradient运算)

@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss_value = loss_fn(y, logits)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    train_acc_metric.update_state(y, logits)
    return loss_value

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值