XGBoost解析系列-数据加载


0.前言

  本文主要介绍XGBoost中数据加载过程,主要是DMatrix::Load内容。

1. XGBoost数据加载

1.1 DMatrix::Load主流程

  数据集加载语句为:

std::shared_ptr<DMatrix> dtrain(DMatrix::Load(param.train_path, param.silent != 0, param.dsplit == 2));

# 函数原型为:DMatrix* DMatrix::Load(const std::string& uri, bool silent, bool load_row_split, const std::string& file_format) 

  可见,DMatrix::Load返回为DMatrix对象指针,参数传入为:1)文件URI地址,2)silent开关,为true打印统计信息,3)load_row_split为分布式开关,为true,则对数据进行分片shard。4)file_format为数据解析格式,默认参数为"auto",自动解析文件数据格式。基于mushroom.conf配置, 提供训练数据地址:xgboost目录下demo/data/agaricus.txt.train,数据为libsvm格式。DMatrix::Load主流程如下:
  1. uri带有#符号,解析出cache_file文件,考虑分布式模式的情况;
  2. 分布式模式,获取当前主机排序partid以及主机总数npart,单机下partid=1;npart=1
  3. 满足file_format == "auto" && npart == 1,检测文件是否为二进制数据文件,检查开头魔方是否为SimpleCSRSource::kMagic,若是则LoadBinary直接初始化SimpleCSRSource数据源对象source
  4. 构建解析器parser,基于工厂设计模式,基类dmlc::Parser<uint32_t>调用静态方法Create(), 尽管格式解析为”auto”,目前使用libsvm格式解析。
  5. 解析器parser构建返回DMatrix对象:dmat = DMatrix::Create(parser.get(), cache_file),核心过程,后续详解。
  6. 根据参数slient打印相关统计信息,尝试读取.group后缀的文件;.base_margin后缀的boost初始设定值;.weight样本的权重,用于代价敏感学习,不存在则跳过。

1.2 解析器parser构建过程

  该小节先对步骤4进行详解,Parser<uint32_t>继承于DataIter<RowBlock<uint32_t> >,实际上内部数据以CSR格式。基于静态方法Parser<uint32_t>::Create()构建实例,该方法调用普通方法CreateParser_方法。由于配置最后解析为libsvm格式,使用ParserFactoryReg工厂来找到对应的LibSVMParser解析器对象。具体代码参考如下:

// 通过ptype=libsvm找到已注册的解析器工厂方法,Get()->Find()后面会看到
const ParserFactoryReg<IndexType>* e =
      Registry<ParserFactoryReg<IndexType> >::Get()->Find(ptype);
// 通过工厂方法生成解析器对象,最终调用的是CreateLibSVMParser<uint32_t>函数
return (*e->body)(spec.uri, spec.args, part_index, num_parts);

  理解上述的代码需要梳理以下的类定义与宏定义过程:

ParserFactoryReg解析器工厂类定义

// 工厂模式,通过宏定义来注册组件,继承FunctionRegEntryBase
// 第1个类参数必须是本身类型,第2个类参数是工厂方法。
// ParserFactoryReg会绑定工厂方法类型,通过工厂方法来组件对象,即Parser<IndexType>::Factor为函数类型
// 后期会调用DMLC_REGISTRY_ENABLE实例生成静态单例工厂对象。宏定义实现可扩展的工厂模式+对象单例非常精彩
template<typename IndexType>
struct ParserFactoryReg : public FunctionRegEntryBase
<ParserFactoryReg<IndexType>, typename Parser<IndexType>::Factory> {};

FunctionRegEntryBase工厂基类模板

// FunctionRegEntryBase方法注册类模板,需要注册项类型,方法类型
template<typename EntryType, typename FunctionType>
class FunctionRegEntryBase {
 public:
  std::string name;                         // 注册项名字
  std::string description;                  // 注册项描述
  std::vector<ParamFieldInfo> arguments;    
  • 2
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值