TensorFlow 2.0 推荐的训练方式写法

日萌社

 

人工智能AI:Keras PyTorch MXNet TensorFlow PaddlePaddle 深度学习实战(不定时更新)


TensorFlow 2.0 推荐的训练方式写法

  • 构建训练模型与函数
# 构建模型
model = build_model(
    vocab_size = len(vocab),
    embedding_dim=embedding_dim,
    rnn_units=rnn_units,
    batch_size=BATCH_SIZE)

# 选择优化器
optimizer = tf.keras.optimizers.Adam()


# 编写带有装饰器@tf.function的函数进行训练
@tf.function
def train_step(inp, target):
    """
    :param inp: 模型输入
    :param tatget: 输入对应的标签
    """
    # 打开梯度记录管理器
    with tf.GradientTape() as tape:
        # 使用模型进行预测
        predictions = model(inp)
        # 使用sparse_categorical_crossentropy计算平均损失
        loss = tf.reduce_mean(
            tf.keras.losses.sparse_categorical_crossentropy(
                target, predictions, from_logits=True))
    # 使用梯度记录管理器求解全部参数的梯度
    grads = tape.gradient(loss, model.trainable_variables)
    # 使用梯度和优化器更新参数
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    # 返回平均损失
    return loss

  • 进行训练:
# 训练轮数
EPOCHS = 10

#进行轮数循环
for epoch in range(EPOCHS):
    # 获得开始时间
    start = time.time()
    # 初始化隐层状态
    hidden = model.reset_states()
    # 进行批次循环
    for (batch_n, (inp, target)) in enumerate(dataset):
        # 调用train_step进行训练, 获得批次循环的损失
        loss = train_step(inp, target)
        # 每100个批次打印轮数,批次和对应的损失
        if batch_n % 100 == 0:
            template = 'Epoch {} Batch {} Loss {}'
            print(template.format(epoch+1, batch_n, loss))

    # 每5轮保存一次检测点
    if (epoch + 1) % 5 == 0:
        model.save_weights(checkpoint_prefix.format(epoch=epoch))

    # 打印轮数,当前损失,和训练耗时
    print ('Epoch {} Loss {:.4f}'.format(epoch+1, loss))
    print ('Time taken for 1 epoch {} sec\n'.format(time.time() - start))


# 保存最后的检测点
model.save_weights(checkpoint_prefix.format(epoch=epoch))

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

あずにゃん

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值