XGBoost (5) C++ 命令行训练源码分析

本文深入分析XGBoost C++命令行训练流程,包括模型初始化、数据读取、训练迭代、预测及决策树生长等关键步骤。详细解释了配置读取、数据加载、模型更新及评估过程,揭示了XGBoost训练背后的逻辑。
摘要由CSDN通过智能技术生成

回顾与本文目标

前一篇文章(第 4 篇)中,我们在命令行接口模式下运行 XGBoost 成功。在第 3 篇中,我们在 Python 接口下运行了 XGBoost,并且分析了其中的 PredictRaw 函数。但这里的代码读得有点乱。所以我会想到打算从头开始,一个一个函数解释代码是如何运行的。但事后发现,这件事情是不是特别容易的一件事情,如果强行这样做会让整篇文章显得异常冗长,从而丧失了写文章的意义。

所以,我打算把自己理解下来的部分先贴在前面。分割线之后的是细节部分。如果需要用到再去查阅即可。

小结

整个训练的过程的 大致逻辑 如下:

  1. 首先需要定义问题:训练数据集是什么?测试数据集是什么?中间评判的准则是什么?训练时的目标函数是什么?迭代更新的参数是什么?(这两个问题决定了迭代法求解优化问题的具体格式。)这些设置都存于配置文件中,程序第一步是读取 配置 (conf),然后转存为 参数 (param)
  2. 根据参数,把该读的数据文件载入 学习器 (leaner),然后初始化模型,开始准备训练。
  3. 由于模型是生长模式的,生长的次数是 num_round。生长出的模型可以是线性模型(GBLinear)也可以是决策树(GBTree)无论如何,外层循环是使得模型不断生长的循环,在 CLITrain() 函数中调用。
  4. 每一个循环的实现者是 UpdateOneIter()。这里的一个循环执行完之后,模型数量就会 +1 (对一个输出的情形,如有 N 个输出,模型数量会 +N,to be verified, 没有试过,欢迎批评指正)。主要做三步,第一步:求模型当前的函数值。第二步,求一阶导数、二阶导数。第三步,生长模型。

    : 这里容易让人费解的是代码中的函数名:gradhess 看上去像梯度和 Hessian。第一反应很有可能是:诶,这两个东西,一个是向量一个是矩阵?对 N N N 个数据点和 M M M 个输出,莫非这两个变量是三阶、四阶张量?但实际上,是目标函数关于第二个变量 y ^ \hat y y^ 的一阶导数和二阶导数,都是标量,结合上数据点和输出的维度,两个量都是矩阵。

  5. 对树模型,第三步主要是找最适当的分割点。实践中,算法主要实现了 Algorithm 3 的算法(而非)

    : 这里容易让人不解的是分割 Algorithm 2。默认情况下没有执行这个算法。

一些实现时的一开始不太让人理解的技巧:

  1. 分割点存于 sindex_ 中。这个变量有点变态,定义在 RegTree 里,是一个 unsigned 类型的变量。这个变量实际上复合了两个信息。第一是分割的输入向量(或者叫特征向量)的序号。由于 unsigned 是 32 位的,这里实际上用了后面 31 位,保存这个序号。而第一位是用来保存 左分割 还是 右分割
  2. 在树模型中,每一棵树的叶片最终算出一个得分。所有得分累加,得出总得分。如果是回归问题,那就可以直接和真实结果进行比对。如果是二元分类问题,那还要在最外层作用 sigmoid 函数,最后再计算交叉熵得到目标函数(误差)。(这实际上应该是 logistic 回归的基础知识,但放在这个语境下,有时候会忘掉目标函数是怎么被求出来的)。

下面还能看看的,可能只有两张图了,请看官随意翻阅。

谢谢观赏!


起点 main() @ src/cli_main.cc

int main(int argc, char *argv[]) {
   
  return xgboost::CLIRunTask(argc, argv);
}

这是程序的入口。作为我们的例子,argc 显然是 2,argv 是两个字符串:xgboost 文件的路径全称和配置文件 mushroom.conf。这个函数会调用 CLIRunTask() 就在主函数上方。


通过命令行接口执行任务 CLIRunTask() @ src/cli_main.cc

int CLIRunTask(int argc, char *argv[]) {
   
  if (argc < 2) {
   
    printf("Usage: <config>\n");
    return 0;
  }
  rabit::Init(argc, argv);

  std::vector<std::pair<std::string, std::string> > cfg;
  cfg.emplace_back("seed", "0");

  common::ConfigIterator itr(argv[1]);
  while (itr.Next()) {
   
    cfg.emplace_back(std::string(itr.Name()), std::string(itr.Val()));
  }

  for (int i = 2; i < argc; ++i) {
   
    char name[256], val[256];
    if (sscanf(argv[i], "%[^=]=%s", name, val) == 2) {
   
      cfg.emplace_back(std::string(name), std::string(val));
    }
  }
  CLIParam param;
  param.Configure(cfg);

  switch (param.task) {
   
    case kTrain: CLITrain(param); break;
    case kDumpModel: CLIDumpModel(param); break;
    case kPredict: CLIPredict(param); break;
  }
  rabit::Finalize();
  return 0;
}

这个函数甚至于可以直接和主函数合并。如果写出它的中文伪代码,可以是:

  1. 如果输入没有参数,那么就打印一句提示,并且直接返回。
  2. 调用并行环境的构造函数
  3. 定义 配置变量,类似于一个字典
    1. 第一个成员:{种子:零} 1
    2. 再定义一个 配置循环体,以 配置文件名 初始化。
    3. 做循环,直到配置循环体没有下一个,把信息(名字、值)都传入配置变量
    4. 然后把文件外配置存入配置变量中。
  4. 定义 参数变量
    1. 把配置变量存入参数变量
  5. 根据参数中的任务值执行任务,有三种选择:训练 0 [TODO]、导出 1、预测 2
  6. 最后调用并行环境的析构函数
    在默认的情况下,paramtask 是零,因此对应训练。

命令行训练任务 CLITrain() @ src/cli_main.cc

void CLITrain(const CLIParam& param) {
   
  const double tstart_data_load = dmlc::GetTime();
  if (rabit::IsDistributed()) {
   
    std::string pname = rabit::GetProcessorName();
    LOG(CONSOLE) << "start " << pname << ":" << rabit::GetRank();
  }
  // load in data.
  std::shared_ptr<DMatrix> dtrain(
      DMatrix::Load(
          param.train_path,
          ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(),
          param.dsplit == 2));
  std::vector<std::shared_ptr<DMatrix> > deval;
  std::vector<std::shared_ptr<DMatrix> > cache_mats;
  std::vector<DMatrix*> eval_datasets;
  cache_mats.push_back(dtrain);
  for (size_t i = 0; i < param.eval_data_names.size(); ++i) {
   
    deval.emplace_back(
        std::shared_ptr<DMatrix>(DMatrix::Load(
            param.eval_data_paths[i],
            ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(),
            param.dsplit == 2)));
    eval_datasets.push_back(deval.back().get());
    cache_mats.push_back(deval.back());
  }
  std::vector<std::string> eval_data_names = param.eval_data_names;
  if (param.eval_train) {
   
    eval_datasets.push_back(dtrain.get());
    eval_data_names.emplace_back("train");
  }
  // initialize the learner.
  std::unique_ptr<Learner> learner(Learner::Create(cache_mats));
  int version = rabit::LoadCheckPoint(learner.get());
  if (version == 0) {
   
    // initialize the model if needed.
    if (param.model_in != "NULL") {
   
      std::unique_ptr<dmlc::Stream> fi(
          dmlc::Stream::Create(param.model_in.c_str(), "r"));
      learner->Load(fi.get());
      learner->Configure(param.cfg);
    } else {
   
      learner->Configure(param.cfg);
      learner->InitModel();
    }
  }
  LOG(INFO) << "Loading data: " << dmlc::GetTime() - tstart_data_load << " sec";

  // start training.
  const double start = dmlc::GetTime();
  for (int i = version / 2; i < param.num_round; ++i) {
   
    double elapsed = dmlc::GetTime() - start;
    if (version % 2 == 0) {
   
      LOG(INFO) << "boosting round " << i << ", " << elapsed << " sec elapsed";
      learner->UpdateOneIter(i, dtrain.get());
      if (learner->AllowLazyCheckPoint()) {
   
        rabit::LazyCheckPoint(learner.get());
      } else {
   
        rabit::CheckPoint(learner.get());
      }
      version += 1;
    }
    
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值