本文介绍Keras一些常见的验证和调参技巧,快速地验证模型和调节超参。
小技巧:
- CSV数据文件加载
- Dense初始化警告
验证与调参:
- 模型验证(Validation)
- K重交叉验证(K-fold Cross-Validation)
- 网格搜索验证(Grid Search Cross-Validation)
Keras
CSV数据文件加载
使用NumPy的 loadtxt() 方法加载CSV数据文件
- delimiter:数据单元的分割符;
- skiprows:略过首行标题;
Dense初始化警告
Dense初始化参数的警告:
将init参数替换为 kernel_initializer 参数即可。
模型验证
在 fit() 中 自动 划分验证集:
通过设置参数 validation_split 的值(0~1)确定验证集的比例。
实现:
在 fit() 中 手动 划分验证集:
train_test_split 来源sklearn.model_selection:
- test_size :验证集的比例;
- random_state :随机数的种子;
通过参数 validation_data 添加验证数据,格式是 数据+标签 的元组。
实现:
交叉验证
K重交叉验证(K-fold Cross-Validation)是常见的模型评估统计。
人工模式
交叉验证函数 StratifiedKFold() 来源于sklearn.model_selection:
- n_splits :交叉的重数,即N重交叉验证;
- shuffle :数据和标签是否随机洗牌;
- random_state :随机数种子;
- skf.split(X, y) :划分数据和标签的索引。
cvscores用于统计K重交叉验证的结果,计算均值和方差。
实现:
输出:
Wrapper模式
通过 cross_val_score() 函数集成模型和交叉验证逻辑。
- 将模型封装成wrapper,注意使用 内置函数 ,而 非 调用,没有括号 () 。
- epochs 即轮次, batch_size 即批次数;
- StratifiedKFold是K重交叉验证的逻辑;
cross_val_score 的输入是模型wrapper、数据X、标签Y、交叉验证cv;输出是每次验证的结果,再计算均值和方差。
实现:
输出:
网格搜索验证
网格搜索验证(Grid Search Cross-Validation)用于选择模型的最优超参值。
交叉验证函数 GridSearchCV() 来源于sklearn.model_selection:
- 设置超参列表,如optimizers、 init_modes 、epochs、batches;
- 创建参数字典,key值是模型的参数,或者wrapper的参数;
- estimator是模型, param_grid 是网格参数字典, n_jobs 是进程数;
- 输出最优结果和其他排列组合结果。
实现:
输出: