大模型与普通机器学习的数据集划分对比及实践
目录
- 数据集划分基础概念
- 普通机器学习的数据划分方法
- 大模型场景的特殊挑战
- 典型场景案例与代码实现
- 4.1 普通机器学习案例(鸢尾花分类)
- 4.2 大模型场景案例(NLP/CV/多模态)
- 核心差异对比表
- 挑战与解决方案
- 总结
1. 数据集划分基础概念
- 训练集:模型学习的主要数据源
- 验证集:超参数调优与早停判断
- 测试集:最终性能评估(仅用一次)
- 关键原则:数据分布一致性、严格隔离测试集
2. 普通机器学习的数据划分方法
典型流程(以Scikit-learn为例):
from sklearn.model_selection import train_test_split
X, y = load_data()
X_train, X_temp, y_train, y_temp = train_test_split(
X, y, test_size=0.3, stratify=y, random_state=42)
X_val, X_test, y_val, y_test = train_test_split(
X_temp, y_temp, test_size=0.5, stratify=y_temp, random_state=42)
print(f"划分结果: 训练集{len(X_train)}, 验证集{len(X_val)}, 测试集{len(X_test)}")
3. 大模型场景的特殊挑战
关键差异点:
维度 | 普通机器学习 | 大模型 |
---|
数据规模 | MB~GB级 | TB~PB级 |
划分策略 | 随机抽样 | 时间划分/领域划分 |
验证集使用 | 多次使用 | 单次或少次评估 |
数据存储 | 内存加载 | 分布式存储+流式加载 |
典型场景 | 单任务 | 多任务/持续学习 |
4. 典型场景案例与代码实现
4.1 普通机器学习案例
场景:鸢尾花分类
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, stratify=y, random_state=42)
X_val, X_test, y_val, y_test = train_test_split(
X_test, y_test, test_size=0.5, stratify=y_test, random_state=42)
print(f"训练集: {X_train.shape[0]} samples")
print(f"验证集: {X_val.shape[0]} samples")
print(f"测试集: {X_test.shape[0]} samples")
4.2 大模型场景案例
Case 1: NLP预训练(HuggingFace实践)
from datasets import load_dataset
dataset = load_dataset("wikipedia", "20220301.en", split='train', streaming=True)
train_ratio = 0.98
val_ratio = 0.01
test_ratio = 0.01
train_dataset = dataset.take(int(1e6 * train_ratio))
remaining = dataset.skip(int(1e6 * train_ratio))
val_dataset = remaining.take(int(1e6 * val_ratio))
test_dataset = remaining.skip(int(1e6 * val_ratio)).take(int(1e6 * test_ratio))
train_iter = iter(train_dataset)
Case 2: 时间序列数据(金融预测)
df = load_stock_data().sort_values('timestamp')
split_idx = int(len(df)*0.95)
train_df = df.iloc[:split_idx]
temp_df = df.iloc[split_idx:]
val_cutoff = temp_df['timestamp'].quantile(0.5)
val_df = temp_df[temp_df.timestamp <= val_cutoff]
test_df = temp_df[temp_df.timestamp > val_cutoff]
Case 3: 多模态数据(图文匹配)
import pandas as pd
from sklearn.model_selection import GroupShuffleSplit
metadata = pd.read_csv("image_text_pairs.csv")
splitter = GroupShuffleSplit(test_size=0.2, n_splits=1, random_state=42)
train_idx, test_idx = next(splitter.split(metadata, groups=metadata['sample_id']))
train_meta = metadata.iloc[train_idx]
test_meta = metadata.iloc[test_idx]
splitter_val = GroupShuffleSplit(test_size=0.5)
train_idx_val, val_idx = next(splitter_val.split(train_meta, groups=train_meta['sample_id']))
5. 核心差异对比表
操作项 | 普通模型 | 大模型 |
---|
划分比例 | 通常70-20-10 | 常为98-1-1(大数据场景) |
数据顺序 | 随机打乱 | 可能保留时序/领域顺序 |
数据存储格式 | CSV/NumPy数组 | TFRecord/Parquet/内存映射文件 |
划分执行位置 | 内存操作 | 分布式存储系统级操作 |
重复使用验证集 | 常见(如交叉验证) | 严格限制次数 |
6. 挑战与解决方案
常见问题:
- 数据分布偏移 → 领域自适应划分策略
- 内存限制 → 使用
datasets
库的磁盘映射功能 - 多模态对齐 → GroupShuffleSplit保持关联
- 评估成本高 → 小比例验证集+Proxy指标
优化代码示例(内存映射):
from datasets import load_from_disk
dataset = load_from_disk("path/to/dataset")
dataset = dataset.class_encode_column("label")
split = dataset.train_test_split(
test_size=0.01,
stratify_by_column="label",
seed=42
)
7. 总结
- 大模型划分核心:强调数据分布的时空一致性,而非简单随机划分
- 工程实践要点:
- 使用流式处理应对大数据规模
- 严格隔离测试集避免间接泄漏
- 采用领域/时间划分代替随机划分
- 未来趋势:自动化数据划分策略(AutoML方向)