回顾与本文目标
前一篇文章(第 4 篇)中,我们在命令行接口模式下运行 XGBoost 成功。在第 3 篇中,我们在 Python 接口下运行了 XGBoost,并且分析了其中的 PredictRaw
函数。但这里的代码读得有点乱。所以我会想到打算从头开始,一个一个函数解释代码是如何运行的。但事后发现,这件事情是不是特别容易的一件事情,如果强行这样做会让整篇文章显得异常冗长,从而丧失了写文章的意义。
所以,我打算把自己理解下来的部分先贴在前面。分割线之后的是细节部分。如果需要用到再去查阅即可。
小结
整个训练的过程的 大致逻辑 如下:
- 首先需要定义问题:训练数据集是什么?测试数据集是什么?中间评判的准则是什么?训练时的目标函数是什么?迭代更新的参数是什么?(这两个问题决定了迭代法求解优化问题的具体格式。)这些设置都存于配置文件中,程序第一步是读取 配置 (conf),然后转存为 参数 (param)
- 根据参数,把该读的数据文件载入 学习器 (leaner),然后初始化模型,开始准备训练。
- 由于模型是生长模式的,生长的次数是
num_round
。生长出的模型可以是线性模型(GBLinear
)也可以是决策树(GBTree
)无论如何,外层循环是使得模型不断生长的循环,在CLITrain()
函数中调用。 - 每一个循环的实现者是
UpdateOneIter()
。这里的一个循环执行完之后,模型数量就会 +1 (对一个输出的情形,如有 N 个输出,模型数量会 +N,to be verified, 没有试过,欢迎批评指正)。主要做三步,第一步:求模型当前的函数值。第二步,求一阶导数、二阶导数。第三步,生长模型。注: 这里容易让人费解的是代码中的函数名:
grad
和hess
看上去像梯度和 Hessian。第一反应很有可能是:诶,这两个东西,一个是向量一个是矩阵?对 N N N 个数据点和 M M M 个输出,莫非这两个变量是三阶、四阶张量?但实际上,是目标函数关于第二个变量 y ^ \hat y y^ 的一阶导数和二阶导数,都是标量,结合上数据点和输出的维度,两个量都是矩阵。 - 对树模型,第三步主要是找最适当的分割点。实践中,算法主要实现了 Algorithm 3 的算法(而非)
注: 这里容易让人不解的是分割 Algorithm 2。默认情况下没有执行这个算法。
一些实现时的一开始不太让人理解的技巧:
- 分割点存于
sindex_
中。这个变量有点变态,定义在RegTree
里,是一个unsigned
类型的变量。这个变量实际上复合了两个信息。第一是分割的输入向量(或者叫特征向量)的序号。由于unsigned
是 32 位的,这里实际上用了后面 31 位,保存这个序号。而第一位是用来保存 左分割 还是 右分割。 - 在树模型中,每一棵树的叶片最终算出一个得分。所有得分累加,得出总得分。如果是回归问题,那就可以直接和真实结果进行比对。如果是二元分类问题,那还要在最外层作用 sigmoid 函数,最后再计算交叉熵得到目标函数(误差)。(这实际上应该是 logistic 回归的基础知识,但放在这个语境下,有时候会忘掉目标函数是怎么被求出来的)。
下面还能看看的,可能只有两张图了,请看官随意翻阅。
谢谢观赏!
起点 main()
@ src/cli_main.cc
int main(int argc, char *argv[]) {
return xgboost::CLIRunTask(argc, argv);
}
这是程序的入口。作为我们的例子,argc
显然是 2,argv
是两个字符串:xgboost
文件的路径全称和配置文件 mushroom.conf
。这个函数会调用 CLIRunTask()
就在主函数上方。
通过命令行接口执行任务 CLIRunTask()
@ src/cli_main.cc
int CLIRunTask(int argc, char *argv[]) {
if (argc < 2) {
printf("Usage: <config>\n");
return 0;
}
rabit::Init(argc, argv);
std::vector<std::pair<std::string, std::string> > cfg;
cfg.emplace_back("seed", "0");
common::ConfigIterator itr(argv[1]);
while (itr.Next()) {
cfg.emplace_back(std::string(itr.Name()), std::string(itr.Val()));
}
for (int i = 2; i < argc; ++i) {
char name[256], val[256];
if (sscanf(argv[i], "%[^=]=%s", name, val) == 2) {
cfg.emplace_back(std::string(name), std::string(val));
}
}
CLIParam param;
param.Configure(cfg);
switch (param.task) {
case kTrain: CLITrain(param); break;
case kDumpModel: CLIDumpModel(param); break;
case kPredict: CLIPredict(param); break;
}
rabit::Finalize();
return 0;
}
这个函数甚至于可以直接和主函数合并。如果写出它的中文伪代码,可以是:
- 如果输入没有参数,那么就打印一句提示,并且直接返回。
- 调用并行环境的构造函数
- 定义 配置变量,类似于一个字典
- 第一个成员:
{种子:零}
1。 - 再定义一个 配置循环体,以 配置文件名 初始化。
- 做循环,直到配置循环体没有下一个,把信息(名字、值)都传入配置变量
- 然后把文件外配置存入配置变量中。
- 第一个成员:
- 定义 参数变量
- 把配置变量存入参数变量
- 根据参数中的任务值执行任务,有三种选择:训练
0
[TODO]、导出1
、预测2
- 最后调用并行环境的析构函数
在默认的情况下,param
的task
是零,因此对应训练。
命令行训练任务 CLITrain()
@ src/cli_main.cc
void CLITrain(const CLIParam& param) {
const double tstart_data_load = dmlc::GetTime();
if (rabit::IsDistributed()) {
std::string pname = rabit::GetProcessorName();
LOG(CONSOLE) << "start " << pname << ":" << rabit::GetRank();
}
// load in data.
std::shared_ptr<DMatrix> dtrain(
DMatrix::Load(
param.train_path,
ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(),
param.dsplit == 2));
std::vector<std::shared_ptr<DMatrix> > deval;
std::vector<std::shared_ptr<DMatrix> > cache_mats;
std::vector<DMatrix*> eval_datasets;
cache_mats.push_back(dtrain);
for (size_t i = 0; i < param.eval_data_names.size(); ++i) {
deval.emplace_back(
std::shared_ptr<DMatrix>(DMatrix::Load(
param.eval_data_paths[i],
ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(),
param.dsplit == 2)));
eval_datasets.push_back(deval.back().get());
cache_mats.push_back(deval.back());
}
std::vector<std::string> eval_data_names = param.eval_data_names;
if (param.eval_train) {
eval_datasets.push_back(dtrain.get());
eval_data_names.emplace_back("train");
}
// initialize the learner.
std::unique_ptr<Learner> learner(Learner::Create(cache_mats));
int version = rabit::LoadCheckPoint(learner.get());
if (version == 0) {
// initialize the model if needed.
if (param.model_in != "NULL") {
std::unique_ptr<dmlc::Stream> fi(
dmlc::Stream::Create(param.model_in.c_str(), "r"));
learner->Load(fi.get());
learner->Configure(param.cfg);
} else {
learner->Configure(param.cfg);
learner->InitModel();
}
}
LOG(INFO) << "Loading data: " << dmlc::GetTime() - tstart_data_load << " sec";
// start training.
const double start = dmlc::GetTime();
for (int i = version / 2; i < param.num_round; ++i) {
double elapsed = dmlc::GetTime() - start;
if (version % 2 == 0) {
LOG(INFO) << "boosting round " << i << ", " << elapsed << " sec elapsed";
learner->UpdateOneIter(i, dtrain.get());
if (learner->AllowLazyCheckPoint()) {
rabit::LazyCheckPoint(learner.get());
} else {
rabit::CheckPoint(learner.get());
}
version += 1;
}