0.前言
本文主要介绍XGBoost中数据加载过程,主要是DMatrix::Load
内容。
1. XGBoost数据加载
1.1 DMatrix::Load主流程
数据集加载语句为:
std::shared_ptr<DMatrix> dtrain(DMatrix::Load(param.train_path, param.silent != 0, param.dsplit == 2));
# 函数原型为:DMatrix* DMatrix::Load(const std::string& uri, bool silent, bool load_row_split, const std::string& file_format)
可见,DMatrix::Load
返回为DMatrix
对象指针,参数传入为:1)文件URI
地址,2)silent
开关,为true打印统计信息,3)load_row_split
为分布式开关,为true,则对数据进行分片shard。4)file_format
为数据解析格式,默认参数为"auto"
,自动解析文件数据格式。基于mushroom.conf
配置, 提供训练数据地址:xgboost目录下demo/data/agaricus.txt.train
,数据为libsvm
格式。DMatrix::Load
主流程如下:
1. uri带有#符号,解析出cache_file
文件,考虑分布式模式的情况;
2. 分布式模式,获取当前主机排序partid
以及主机总数npart
,单机下partid=1;npart=1
;
3. 满足file_format == "auto" && npart == 1
,检测文件是否为二进制数据文件,检查开头魔方是否为SimpleCSRSource::kMagic
,若是则LoadBinary
直接初始化SimpleCSRSource
数据源对象source
。
4. 构建解析器parser
,基于工厂设计模式,基类dmlc::Parser<uint32_t>
调用静态方法Create()
, 尽管格式解析为”auto”,目前使用libsvm
格式解析。
5. 解析器parser
构建返回DMatrix
对象:dmat = DMatrix::Create(parser.get(), cache_file)
,核心过程,后续详解。
6. 根据参数slient打印相关统计信息,尝试读取.group
后缀的文件;.base_margin
后缀的boost初始设定值;.weight
样本的权重,用于代价敏感学习,不存在则跳过。
1.2 解析器parser构建过程
该小节先对步骤4进行详解,Parser<uint32_t>
继承于DataIter<RowBlock<uint32_t> >
,实际上内部数据以CSR格式
。基于静态方法Parser<uint32_t>::Create()
构建实例,该方法调用普通方法CreateParser_
方法。由于配置最后解析为libsvm
格式,使用ParserFactoryReg
工厂来找到对应的LibSVMParser
解析器对象。具体代码参考如下:
// 通过ptype=libsvm找到已注册的解析器工厂方法,Get()->Find()后面会看到
const ParserFactoryReg<IndexType>* e =
Registry<ParserFactoryReg<IndexType> >::Get()->Find(ptype);
// 通过工厂方法生成解析器对象,最终调用的是CreateLibSVMParser<uint32_t>函数
return (*e->body)(spec.uri, spec.args, part_index, num_parts);
理解上述的代码需要梳理以下的类定义与宏定义过程:
ParserFactoryReg解析器工厂类定义
// 工厂模式,通过宏定义来注册组件,继承FunctionRegEntryBase
// 第1个类参数必须是本身类型,第2个类参数是工厂方法。
// ParserFactoryReg会绑定工厂方法类型,通过工厂方法来组件对象,即Parser<IndexType>::Factor为函数类型
// 后期会调用DMLC_REGISTRY_ENABLE实例生成静态单例工厂对象。宏定义实现可扩展的工厂模式+对象单例非常精彩
template<typename IndexType>
struct ParserFactoryReg : public FunctionRegEntryBase
<ParserFactoryReg<IndexType>, typename Parser<IndexType>::Factory> {};
FunctionRegEntryBase工厂基类模板
// FunctionRegEntryBase方法注册类模板,需要注册项类型,方法类型
template<typename EntryType, typename FunctionType>
class FunctionRegEntryBase {
public:
std::string name; // 注册项名字
std::string description; // 注册项描述
std::vector<ParamFieldInfo> arguments;