模型训练
模型训练是机器学习和深度学习中调整模型参数以优化性能的过程,通常包括以下步骤:
- 数据准备:加载数据、预处理、分割训练集和测试集。
- 模型定义:选择合适的算法或网络架构。
- 损失函数与优化器:定义训练目标(损失函数)和优化算法。
- 训练过程:通过迭代更新模型参数,使模型在训练集上表现更优。
- 验证与测试:通过验证集或测试集评估模型的性能,避免过拟合。
模型训练的关键步骤
1. 数据准备
- 数据加载:加载原始数据(如 CSV、图像文件等)。
- 数据预处理:
- 特征归一化或标准化。
- 数据增强(图像翻转、裁剪等)。
- 划分数据集:将数据分为训练集、验证集和测试集。
示例代码:
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
# 加载数据
data = load_iris()
X, y = data.data, data.target
# 特征归一化
scaler = StandardScaler()
X = scaler.fit_transform(X)
# 划分数据集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
2. 定义模型
- 根据任务选择模型架构。
- 对于深度学习,可以定义神经网络的层次结构。
示例代码(TensorFlow):
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
# 创建一个简单的神经网络
model = Sequential([
Dense(64, activation='relu', input_shape=(4,)), # 输入特征为4维
Dense(32, activation='relu'),
Dense(3, activation='softmax') # 输出3个类别
])
3. 定义损失函数和优化器
- 损失函数:定义模型优化的目标。
- 回归:MSE、MAE。
- 分类:交叉熵、KL 散度。
- 优化器:选择适合的优化算法(如 SGD、Adam)。
示例代码:
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy', # 适合分类问题
metrics=['accuracy']) # 添加评估指标
4. 训练模型
- 批量训练:通过小批量数据更新参数,提升计算效率。
- 迭代训练:每轮称为一个 epoch,表示遍历整个数据集一次。
- 验证集评估:在每个 epoch 结束后,通过验证集评估模型性能。
示例代码:
history = model.fit(X_train, y_train,
validation_split=0.2, # 从训练集中划分验证集
epochs=50, # 训练50轮
batch_size=16, # 每批大小为16
verbose=1) # 输出训练日志
Epoch 1/50
6/6 [==============================] - 1s 37ms/step - loss: 1.2995 - accuracy: 0.3229 - val_loss: 1.1235 - val_accuracy: 0.1667
Epoch 2/50
6/6 [==============================] - 0s 6ms/step - loss: 1.1551 - accuracy: 0.3542 - val_loss: 1.0245 - val_accuracy: 0.5833
Epoch 3/50
6/6 [==============================] - 0s 6ms/step - loss: 1.0217 - accuracy: 0.5104 - val_loss: 0.9356 - val_accuracy: 0.7917
Epoch 4/50
6/6 [==============================] - 0s 6ms/step - loss: 0.9087 - accuracy: 0.7500 - val_loss: 0.8568 - val_accuracy: 0.8750
Epoch 5/50
6/6 [==============================] - 0s 6ms/step - loss: 0.8113 - accuracy: 0.8021 - val_loss: 0.7840 - val_accuracy: 0.8750
Epoch 6/50
6/6 [==============================] - 0s 5ms/step - loss: 0.7256 - accuracy: 0.8125 - val_loss: 0.7158 - val_accuracy: 0.8750
Epoch 7/50
6/6 [==============================] - 0s 6ms/step - loss: 0.6508 - accuracy: 0.8125 - val_loss: 0.6570 - val_accuracy: 0.8750
Epoch 8/50
6/6 [==============================] - 0s 6ms/step - loss: 0.5906 - accuracy: 0.8229 - val_loss: 0.6048 - val_accuracy: 0.8750
Epoch 9/50
6/6 [==============================] - 0s 6ms/step - loss: 0.5433 - accuracy: 0.8229 - val_loss: 0.5616 - val_accuracy: 0.8750
Epoch 10/50
6/6 [==============================] - 0s 7ms/step - loss: 0.5063 - accuracy: 0.8229 - val_loss: 0.5259 - val_accuracy: 0.8750
Epoch 11/50
6/6 [==============================] - 0s 6ms/step - loss: 0.4748