日萌社
人工智能AI:Keras PyTorch MXNet TensorFlow PaddlePaddle 深度学习实战(不定时更新)
TensorFlow 2.0 推荐的训练方式写法
- 构建训练模型与函数
# 构建模型
model = build_model(
vocab_size = len(vocab),
embedding_dim=embedding_dim,
rnn_units=rnn_units,
batch_size=BATCH_SIZE)
# 选择优化器
optimizer = tf.keras.optimizers.Adam()
# 编写带有装饰器@tf.function的函数进行训练
@tf.function
def train_step(inp, target):
"""
:param inp: 模型输入
:param tatget: 输入对应的标签
"""
# 打开梯度记录管理器
with tf.GradientTape() as tape:
# 使用模型进行预测
predictions = model(inp)
# 使用sparse_categorical_crossentropy计算平均损失
loss = tf.reduce_mean(
tf.keras.losses.sparse_categorical_crossentropy(
target, predictions, from_logits=True))
# 使用梯度记录管理器求解全部参数的梯度
grads = tape.gradient(loss, model.trainable_variables)
# 使用梯度和优化器更新参数
optimizer.apply_gradients(zip(grads, model.trainable_variables))
# 返回平均损失
return loss
- 进行训练:
# 训练轮数
EPOCHS = 10
#进行轮数循环
for epoch in range(EPOCHS):
# 获得开始时间
start = time.time()
# 初始化隐层状态
hidden = model.reset_states()
# 进行批次循环
for (batch_n, (inp, target)) in enumerate(dataset):
# 调用train_step进行训练, 获得批次循环的损失
loss = train_step(inp, target)
# 每100个批次打印轮数,批次和对应的损失
if batch_n % 100 == 0:
template = 'Epoch {} Batch {} Loss {}'
print(template.format(epoch+1, batch_n, loss))
# 每5轮保存一次检测点
if (epoch + 1) % 5 == 0:
model.save_weights(checkpoint_prefix.format(epoch=epoch))
# 打印轮数,当前损失,和训练耗时
print ('Epoch {} Loss {:.4f}'.format(epoch+1, loss))
print ('Time taken for 1 epoch {} sec\n'.format(time.time() - start))
# 保存最后的检测点
model.save_weights(checkpoint_prefix.format(epoch=epoch))