解析xgb模型成sql语句
xgb是一种功能强大,被广泛使用的树模型。树模型的本质是一组if-else组合。
训练好的xgb模型如果需要对大数据进行运算,往往需要分布式的环境。Hive是我们常用的处理数据的环境,一些模型运算往往也在其中进行。
我们可以使用以下三种方式,在Hive中计算xgb模型:
- 使用python transform的形式,如果没有安装xgb的包,还需要自己解析模型文件;
- 将xgb模型解析成sql,运行hive sql;
- 用Java解析xgb模型,开发UDF函数。
本文通过递归的方式,尝试将xgb模型解析为sql语句,帮助更好地理解xgb模型。
# -*- coding: utf-8 -*-
#!/bin/python
import codecs
import json
## 解析单棵树
def parse_xgb_tree_2sql(xgb_tree_json, mid_sqls, tree_num, depth=0):
    """Recursively render one XGBoost tree node as a nested SQL ``case when``.

    SQL fragments are appended to ``mid_sqls`` in output order and nothing is
    returned; concatenating the fragments gives the expression for the node.

    Args:
        xgb_tree_json: dict describing one node of a JSON tree dump. A leaf
            node carries ``leaf``; a split node carries ``split``,
            ``split_condition``, ``yes``, ``no``, ``missing`` and a
            ``children`` list, where index 0 is rendered as the ``yes``
            branch and index 1 as the ``no`` branch.
        mid_sqls: list receiving the generated SQL fragments (mutated in
            place).
        tree_num: index of the tree being parsed; kept for interface
            compatibility and not used by this function.
        depth: current recursion depth, used only to indent the output.

    Raises:
        ValueError: if ``missing`` of a split node equals neither its ``yes``
            nor its ``no`` node id, since no valid SQL can be produced then.
    """
    indent = " " * (depth + 1)  # used to align the generated SQL

    if 'leaf' in xgb_tree_json:
        leaf_sql = indent + str(xgb_tree_json['leaf'])
        # A leaf emitted right after an "else" fragment gets a trailing space.
        if mid_sqls and 'else' in mid_sqls[-1]:
            leaf_sql += ' '
        mid_sqls.append(leaf_sql)
        return

    feat = xgb_tree_json['split']
    value = str(xgb_tree_json['split_condition'])
    left_tree = xgb_tree_json['yes']
    right_tree = xgb_tree_json['no']
    missing = xgb_tree_json['missing']

    # Null feature values follow the "missing" branch, so that branch becomes
    # the "when" arm (with an extra "is null" test) and the other the "else".
    if missing == left_tree:
        operator, when_idx, else_idx = '<', 0, 1
    elif missing == right_tree:
        operator, when_idx, else_idx = '>=', 1, 0
    else:
        raise ValueError(
            "node {!r}: 'missing' ({!r}) matches neither 'yes' ({!r}) "
            "nor 'no' ({!r})".format(
                xgb_tree_json.get('nodeid'), missing, left_tree, right_tree))

    children = xgb_tree_json['children']
    condition = "({0} is null or {0} {1} {2})".format(feat, operator, value)

    mid_sqls.append("{}case when {} then\n".format(indent, condition))
    parse_xgb_tree_2sql(children[when_idx], mid_sqls, tree_num, depth + 1)
    mid_sqls.append("\n{}else\n".format(indent))
    parse_xgb_tree_2sql(children[else_idx], mid_sqls, tree_num, depth + 1)
    mid_sqls.append("\n{}end".format(indent))
## 解析模型文件
def parse_xgb_trees(xgb_trees_josn):
    """Convert a sequence of JSON tree dumps into SQL select expressions.

    Args:
        xgb_trees_josn: iterable of strings, each holding the JSON dump of
            one tree (for example the lines of a saved model file). Entries
            that are empty or whitespace-only are skipped.

    Returns:
        A list with one SQL expression per tree, aliased ``tree_<i>_score``
        where ``i`` is the zero-based tree index. Every entry except the last
        ends with a comma and a newline so the entries can be concatenated
        into a select list. The list is empty when no trees are given.
    """
    # Ignore blank lines (e.g. a trailing empty line in the model file)
    # instead of failing inside json.loads.
    tree_dumps = [dump for dump in xgb_trees_josn if dump.strip()]

    tree_sqls = []
    for idx, single_tree in enumerate(tree_dumps):
        mid_sqls = []
        parse_xgb_tree_2sql(json.loads(single_tree), mid_sqls, idx, 0)
        tree_sqls.append(
            ''.join(mid_sqls) + ' as ' + 'tree_' + str(idx) + '_score,' + '\n')

    if tree_sqls:
        # Drop the trailing comma and newline after the last expression.
        tree_sqls[-1] = tree_sqls[-1][:-2]
    return tree_sqls
if __name__ == '__main__':
    # How 'gender_xgb.json' is produced: ``xgb_model`` is the trained model
    # object, and each tree dump is written as one line of JSON.
    #
    #     xgb_json = xgb_model.get_dump(dump_format='json')
    #     with codecs.open('gender_xgb.json', 'w', encoding="utf-8") as f:
    #         for single_json in xgb_json:
    #             single_json = single_json.replace('\n',' ').replace('\r', ' ')
    #             f.write(single_json + '\n')

    # Read the model as UTF-8, the same encoding the SQL file is written
    # with, so decoding does not depend on the platform default.
    with codecs.open('gender_xgb.json', 'r', encoding="utf-8") as f_read:
        xgb_json = f_read.readlines()

    tree_sqls = parse_xgb_trees(xgb_json)

    # All tree expressions as a single string; handy to print when debugging.
    final_sqls = ''.join(tree_sqls)

    with codecs.open('xgb_model.sql', 'w', encoding="utf-8") as f:
        for item_sql in tree_sqls:
            f.write(item_sql + '\n')
模型文件示例:
{ "nodeid": 0, "depth": 0, "split": "f2354", "split_condition": 1393.75, "yes": 1, "no": 2, "missing": 2, "children": [ { "nodeid": 1, "depth": 1, "split": "f4927", "split_condition": 0, "yes": 3, "no": 4, "missing": 3, "children": [ { "nodeid": 3, "depth": 2, "split": "f2790", "split_condition": 95, "yes": 7, "no": 8, "missing": 8, "children": [ { "nodeid": 7, "depth": 3, "split": "f2388", "split_condition": 611, "yes": 15, "no": 16, "missing": 16, "children": [ { "nodeid": 15, "leaf": -0.100344 }, { "nodeid": 16, "leaf": -0.0164536 } ]}, { "nodeid": 8, "depth": 3, "split": "f302", "split_condition": 61, "yes": 17, "no": 18, "missing": 18, "children": [ { "nodeid": 17, "leaf": 0.0484211 }, { "nodeid": 18, "leaf": -0.130714 } ]} ]}, { "nodeid": 4, "depth": 2, "split": "f250", "split_condition": 3013, "yes": 9, "no": 10, "missing": 10, "children": [ { "nodeid": 9, "depth": 3, "split": "f1799", "split_condition": 5.5, "yes": 19, "no": 20, "missing": 19, "children": [ { "nodeid": 19, "leaf": 0.133858 }, { "nodeid": 20, "leaf": 0.00789474 } ]}, { "nodeid": 10, "depth": 3, "split": "f2391", "split_condition": 0, "yes": 21, "no": 22, "missing": 21, "children": [ { "nodeid": 21, "leaf": 0.0243929 }, { "nodeid": 22, "leaf": -0.0782429 } ]} ]} ]}, { "nodeid": 2, "depth": 1, "split": "f4534", "split_condition": 280.5, "yes": 5, "no": 6, "missing": 6, "children": [ { "nodeid": 5, "depth": 2, "split": "f2387", "split_condition": 720.5, "yes": 11, "no": 12, "missing": 12, "children": [ { "nodeid": 11, "depth": 3, "split": "f278", "split_condition": 57, "yes": 23, "no": 24, "missing": 24, "children": [ { "nodeid": 23, "leaf": 0.113749 }, { "nodeid": 24, "leaf": -0.064978 } ]}, { "nodeid": 12, "depth": 3, "split": "f4796", "split_condition": 4499.5, "yes": 25, "no": 26, "missing": 26, "children": [ { "nodeid": 25, "leaf": 0.155554 }, { "nodeid": 26, "leaf": 0.0842534 } ]} ]}, { "nodeid": 6, "depth": 2, "split": "f5718", "split_condition": 263, "yes": 13, "no": 14, "missing": 14, 
"children": [ { "nodeid": 13, "depth": 3, "split": "f4793", "split_condition": 0, "yes": 27, "no": 28, "missing": 27, "children": [ { "nodeid": 27, "leaf": -0.114908 }, { "nodeid": 28, "leaf": -0.0101812 } ]}, { "nodeid": 14, "depth": 3, "split": "f4798", "split_condition": 310.5, "yes": 29, "no": 30, "missing": 30, "children": [ { "nodeid": 29, "leaf": 0.131054 }, { "nodeid": 30, "leaf": 0.0196653 } ]} ]} ]} ]}
解析后的hive sql语句:
case when (f2354 is null or f2354 >= 1393.75) then
case when (f4534 is null or f4534 >= 280.5) then
case when (f5718 is null or f5718 >= 263) then
case when (f4798 is null or f4798 >= 310.5) then
0.0196653
else
0.131054
end
else
case when (f4793 is null or f4793 < 0) then
-0.114908
else
-0.0101812
end
end
else
case when (f2387 is null or f2387 >= 720.5) then
case when (f4796 is null or f4796 >= 4499.5) then
0.0842534
else
0.155554
end
else
case when (f278 is null or f278 >= 57) then
-0.064978
else
0.113749
end
end
end
else
case when (f4927 is null or f4927 < 0) then
case when (f2790 is null or f2790 >= 95) then
case when (f302 is null or f302 >= 61) then
-0.130714
else
0.0484211
end
else
case when (f2388 is null or f2388 >= 611) then
-0.0164536
else
-0.100344
end
end
else
case when (f250 is null or f250 >= 3013) then
case when (f2391 is null or f2391 < 0) then
0.0243929
else
-0.0782429
end
else
case when (f1799 is null or f1799 < 5.5) then
0.133858
else
0.00789474
end
end
end
end as tree_0_score,