其实是一个课程作业,要求实现 GBDT 算法。在实现的过程中参考了很多资料,也做了很多优化,觉得收获很大,因此把开发的过程也记录了下来。
源代码在 GitHub。
1. 构建与使用
1.1 构建
- Windows: 使用 Visual Studio 2017 打开解决方案并生成即可。
- Linux: 根目录提供了
makefile
文件,使用make
编译即可,需要gcc >= 5.4.0
1.2 使用
-
用法:
boost <config_file> <train_file> <test_file> <predict_dest>
-
接受 LibSVM 格式的训练数据输入,如下每行代表一个训练样本:
<label> <feature-index>:<feature-value> <feature-index>:<feature-value> <feature-index>:<feature-value> 复制代码
-
用于预测的数据输入和训练数据类似:
<id> <feature-index>:<feature-value> <feature-index>:<feature-value> <feature-index>:<feature-value> 复制代码
-
目前只支持二分类问题
-
<config_file>
指定训练参数:eta = 1. # shrinkage rate gamma = 0. # minimum gain required to split a node maxDepth = 6 # max depth allowed minChildWeight = 1 # minimum allowed size for a node to be splitted rounds = 1 # REQUIRED. number of subtrees subsample = 1. # subsampling ratio for each tree colsampleByTree = 1. # tree-wise feature subsampling ratio maxThreads = 1; # max running threads features; # REQUIRED. number of features validateSize = .2 # if greater than 0, input data will be split into two sets and used for training and validation repectively 复制代码
2. 算法原理
GBDT 的核心可以分成两部分,分别是 Gradient Boosting 和 Decision Tree:
- Decision Tree : GBDT 的基分类器,通过划分输入样本的特征使得落在相同特征的样本拥有大致相同的 label。由于在 GBDT 中需要对若干不同的 Decision Tree 的结果进行综合,因此一般采用的是 Regression Tree (回归树)而不是 Classification Tree (分类树)。
- Gradient Boosting: 迭代式的集成算法,每一棵决策树的学习目标 y 都是之前所有树的结论和的残差(即梯度方向),也即 。
3. 实现与优化历程
各个部分的实现均经过若干次“初版实现 - 性能 profiling - 优化得到下一版代码”的迭代。其中,性能 profiling 部分,使用的是 Visual Studio 2017 的“性能探查器”功能,在进行性能 profile 之前均使用 release 模式编译(打开/O2 /Oi
优化选项)。
3.1 数据处理
选择的输入文件数据格式是 Libsvm 的格式,格式如下:
<label> <feature-index>:<feature-value> <feature-index>:<feature-value>
复制代码
可以看到这种格式天然适合用来表示稀疏的数据集,但在实现过程中,为了简单起见以及 cache 性能,我通过将空值填充为 0 转化为密集矩阵形式存储。代价是内存占用会相对高许多。
3.1.1 初版
最初并没有做什么优化,采用的是如下的简单流程:
- 文件按行读取
- 对于每一行内容,先转成
std::stringstream
,再从中解析出相应的数据。
核心代码如下:
ifstream in(path);
string line;
while (getline(in, line)) {
auto item = parseLibSVMLine(move(line), featureCount); // { label, vector }
x.push_back(move(item.first));
y.push_back(item.second);
}
/* in parseLibSVMLine */
stringstream ss(line);
ss >> label;
while (ss) {
char _;
ss >> index >> _ >> value;
values[index - 1] = value;
}
复制代码
profile 结果:
可以看到,主要的耗时在于将一行字符串解析成我们需要的 label + vector 数据这一过程中,进一步分析:
因此得知主要问题在于字符串解析部分。此时怀疑是 std::stringstream
的实现为了线程安全、错误检查等功能牺牲了性能,因此考虑使用 cstdio
中的实现。
3.1.2 改进
将 parseLibSVMLine
的实现重写,使用cstdio
中的sscanf
代替了 std::stringstream
:
int lastp = -1;
for (size_t p = 0; p < line.length(); p++) {
if (isspace(line[p]) || p == line.length() - 1) {
if (lastp == -1) {
sscanf(line.c_str(), "%zu", &label);
}
else {
sscanf(line.c_str() + lastp, "%zu:%lf", &index, &value);
values[index - 1] = value;
}
lastp = int(p + 1);
}
}
复制代码
profile 结果:
可以看到,虽然 parse 部分仍然是计算的热点,但这部分的计算量显著下降(53823 -> 23181),读取完整个数据集的是时间减少了 50% 以上。
3.1.3 最终版
显然,在数据集中,每一行之间的解析任务都是相互独立的,因此可以在一次性读入整个文件并按行划分数据后,对数据的解析进行并行化:
string content;
getline(ifstream(path), content, '\0');
stringstream in(move(content));
vector<string> lines;
string line;
while (getline(in, line)) lines.push_back(move(line));
#pragma omp parallel for
for (int i = 0; i < lines.size(); i++) {
auto item = parseLibSVMLine(move(lines[i]), featureCount);
#pragma omp c