# ------ Data-preparation steps omitted ------
# Optimizer
optimizer = tf.keras.optimizers.SGD(learning_rate=1e-3)
# Loss function: the model is expected to output raw logits, hence from_logits=True
loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
# Metrics tracking accuracy on the training and validation sets
train_acc_metric = tf.keras.metrics.CategoricalAccuracy()
val_acc_metric = tf.keras.metrics.CategoricalAccuracy()
# Prepare the training dataset
# (assumes x_train/y_train and x_val/y_val were built in the omitted step)
batch_size = 64
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)
# Prepare the validation dataset
# (fix: reuse batch_size instead of the hard-coded 64 so both stay in sync)
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(batch_size)
# Training loop (custom, without compile()/fit()).
# Fixes vs. original: restored the indentation lost in the paste, and passed
# an explicit `training=` flag to the model so layers such as Dropout and
# BatchNormalization behave correctly in each phase.
model = MyModel(num_classes=10)
epochs = 3
for epoch in range(epochs):
    # Iterate over the batches of the training dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        # Open a GradientTape to record the operations run during the
        # forward pass; this enables automatic differentiation.
        with tf.GradientTape() as tape:
            logits = model(x_batch_train, training=True)
            loss_value = loss_fn(y_batch_train, logits)
        # Compute gradients and update the parameters.
        grads = tape.gradient(loss_value, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))
        # Update the training metric.
        train_acc_metric(y_batch_train, logits)
    # Display metrics at the end of each epoch.
    # NOTE: loss_value here is the loss of the LAST batch only, not an
    # epoch average.
    train_acc = train_acc_metric.result()
    print(float(train_acc))
    print(float(loss_value))
    # Reset the training metric at the end of each epoch — mandatory,
    # otherwise results accumulate across epochs.
    train_acc_metric.reset_states()
    # Run a validation pass at the end of each epoch.
    for x_batch_val, y_batch_val in val_dataset:
        val_logits = model(x_batch_val, training=False)
        val_acc_metric(y_batch_val, val_logits)
    val_acc = val_acc_metric.result()
    print('Validation acc: %s' % (float(val_acc),))
    val_acc_metric.reset_states()  # must be reset as well
# TensorFlow custom training (without compile()/fit())
# (source article last published 2022-11-15 20:00:38)