使用 TensorFlow 实现自定义训练循环(Custom Training Loop)
默认的
model.fit()
已足够应对大多数任务,但在一些复杂场景下,如多任务学习、自定义损失函数、梯度裁剪等,我们就需要更细粒度的控制 —— 这正是自定义训练循环的用武之地。
✨ 自定义训练循环的核心优势
- 更灵活的控制训练流程
- 支持复杂的模型结构与损失函数
- 可调试性更强(便于插入打印、日志记录等)
- 适合研究性、创新性项目
🧱 主要组成结构
- 前向传播(Forward pass)
- 损失计算(Loss computation)
- 反向传播(Backward pass)
- 权重更新(Weights update)
- 评估与日志记录(Evaluation and logging)
📦 1. 准备工作
导入必要的库并生成一个简单的数据集:
import tensorflow as tf
import numpy as np
# 创建模拟数据集
x = np.random.rand(1000, 3).astype(np.float32)
y = 3 * x[:, 0] + 2 * x[:, 1] - x[:, 2] + 0.5 + np.random.randn(1000).astype(np.float32) * 0.1
🏗 2. 定义模型与优化器
# 定义一个简单的全连接模型
class MyModel(tf.keras.Model):
def __init__(self):
super().__init__()
self.dense1 = tf.keras.layers.Dense(10, activation='relu')
self.out = tf.keras.layers.Dense(1)
def call(self, inputs):
x = self.dense1(inputs)
return self.out(x)
model = MyModel()
optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)
🎯 3. 定义损失函数与评估指标
loss_fn = tf.keras.losses.MeanSquaredError()
train_loss = tf.keras.metrics.Mean(name='train_loss')
🔁 4. 构建训练循环
@tf.function
def train_step(x_batch, y_batch):
with tf.GradientTape() as tape:
predictions = model(x_batch, training=True)
loss = loss_fn(y_batch, predictions)
gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
train_loss.update_state(loss)
🧪 5. 训练过程
BATCH_SIZE = 32
EPOCHS = 10
dataset = tf.data.Dataset.from_tensor_slices((x, y)).shuffle(1000).batch(BATCH_SIZE)
for epoch in range(EPOCHS):
train_loss.reset_state()
for x_batch, y_batch in dataset:
train_step(x_batch, y_batch)
print(f"Epoch {epoch + 1}, Loss: {train_loss.result():.4f}")
📈 6. 验证效果
# 取一个样本预测
test_sample = np.array([[0.5, 0.3, 0.2]])
prediction = model(test_sample)
print(f"预测结果: {prediction.numpy().flatten()[0]:.4f}")
🧠 进阶内容(可选)
你可以进一步在训练循环中加入:
- 自定义正则项
- 多损失函数加权
- 梯度裁剪:
tf.clip_by_norm()
- 学习率调度器:
tf.keras.optimizers.schedules
- 日志记录与 TensorBoard 可视化
✅ 总结
内容 | 方法 |
---|---|
前向传播 | model(x_batch) |
计算损失 | loss_fn(y_batch, prediction) |
反向传播 | tape.gradient(...) |
更新参数 | optimizer.apply_gradients(...) |
📌 使用建议
情况 | 建议使用 |
---|---|
训练逻辑简单 | model.fit() |
需要自定义控制 | Custom Training Loop |
复杂损失、优化逻辑 | 推荐使用自定义循环 |
研究性任务 / 高度定制 | 必须使用自定义循环 |