基于GBDT的在线预测任务深度学习框架
A Deep Learning Framework Distilled by GBDT for Online Prediction Tasks
https://github.com/motefly/DeepGBM
该框架主要使用CatNN、GBDT2NN两部分实现非结构化数据和结构化数据同时处理。
本文优点:
(1):CatNN主要侧重于处理Spares Categories Feature;
(2):GBDT2NN主要侧重于处理Dense Numerical Feature;
(3):online update端到端训练与在线更新
GBDT2NN解决FC全链接层稠密数值特征学习超平面优化陷入局部最优问题
catnn主要利用了embedding嵌入技术将高维度稀疏向量转为稠密向量,GBDT2NN采用GBDT
将数据划分为区域聚类成叶子,然后使用神经网络逼近树结构输出并蒸馏并且扩展到多树
逼近神经网络模型。
训练过程:
离线训练获得GBDT模型,然后GBDT叶子节点嵌入实现端到端训练。在线更新离线模型
时候丢弃分类任务交叉熵损失函数来达到更新嵌入学习部分而不改变GBDT数的特征从而保
证更新有效性。