如何将XGBoost等模型转成Hive-SQL

在项目开发中经常会有这样的需求:将 XGBoost 等线下训练好的模型部署到线上,每天对存量数据进行扫库打分。这里参考大神的做法记录一下,方便以后使用。

来源:https://blog.csdn.net/wgdzz/article/details/87990598

# -*- coding: utf-8 -*-
#!/bin/python

import codecs
import json

## 解析单棵树
def parse_xgb_tree_2sql(xgb_tree_json, mid_sqls, tree_num, depth=0):
    """Append the SQL ``case when`` fragments for one XGBoost tree node.

    The node and its subtree are walked recursively. The generated text
    fragments are appended to ``mid_sqls`` in output order, so joining the
    list with ``''`` yields the nested expression for the whole subtree.

    Args:
        xgb_tree_json: dict for one node of a JSON tree dump. A leaf node
            carries a ``leaf`` key; a split node carries ``split``,
            ``split_condition``, ``yes``, ``no``, ``missing`` and
            ``children``.
        mid_sqls: list of str, extended in place with the fragments.
        tree_num: index of the tree being converted. It is passed through
            to the recursive calls and not otherwise used.
        depth: nesting level of this node; only affects indentation.

    Raises:
        ValueError: if ``missing`` matches neither ``yes`` nor ``no``.
            Continuing in that case would leave a truncated expression.
    """
    indent = "    " * (depth + 1)  # keeps the generated SQL aligned

    if 'leaf' in xgb_tree_json:
        cur_sql = indent + str(xgb_tree_json['leaf'])
        # A leaf that follows an "else" fragment gets a trailing space.
        if mid_sqls and 'else' in mid_sqls[-1]:
            cur_sql += ' '
        mid_sqls.append(cur_sql)
        return

    feat = xgb_tree_json['split']
    value = str(xgb_tree_json['split_condition'])
    yes_id = xgb_tree_json['yes']
    no_id = xgb_tree_json['no']
    missing = xgb_tree_json['missing']

    # The "then" branch is always the one that null values follow, so the
    # null test and the comparison can share a single condition.
    # children[0] is used for ``feat < value`` and children[1] for
    # ``feat >= value``.
    if missing == yes_id:
        operator, then_child, else_child = '<', 0, 1
    elif missing == no_id:
        operator, then_child, else_child = '>=', 1, 0
    else:
        raise ValueError(
            "node {!r}: 'missing' ({!r}) matches neither 'yes' ({!r}) "
            "nor 'no' ({!r})".format(
                xgb_tree_json.get('nodeid'), missing, yes_id, no_id))

    children = xgb_tree_json['children']
    condition = '({0} is null or {0} {1} {2})'.format(feat, operator, value)

    mid_sqls.append("{}case when {} then\n".format(indent, condition))
    parse_xgb_tree_2sql(children[then_child], mid_sqls, tree_num, depth + 1)
    mid_sqls.append("\n{}else\n".format(indent))
    parse_xgb_tree_2sql(children[else_child], mid_sqls, tree_num, depth + 1)
    mid_sqls.append("\n{}end".format(indent))
   
## 解析模型文件
def parse_xgb_trees(xgb_trees_josn):
    """Convert dumped XGBoost trees into one SQL column expression per tree.

    Args:
        xgb_trees_josn: iterable of str, each holding the JSON dump of a
            single tree. Blank or whitespace-only entries are skipped.

    Returns:
        list of str. Entry ``i`` is the nested ``case when`` expression of
        tree ``i`` followed by ``' as tree_<i>_score,'`` and a newline.
        The last entry has that trailing comma and newline removed. An
        empty list is returned when there are no trees.
    """
    tree_sqls = []

    for single_tree in xgb_trees_josn:
        # Tolerate empty lines, e.g. a trailing newline in the dump file.
        if not single_tree.strip():
            continue
        idx = len(tree_sqls)
        mid_sqls = []
        parse_xgb_tree_2sql(json.loads(single_tree), mid_sqls, idx, 0)
        tree_sqls.append(
            ''.join(mid_sqls) + ' as ' + 'tree_' + str(idx) + '_score,' + '\n')

    # Drop the ",\n" after the final column; guard against an empty model.
    if tree_sqls:
        tree_sqls[-1] = tree_sqls[-1][:-2]
    return tree_sqls

if __name__ == '__main__':
    # The input file is expected to hold one JSON tree dump per line.
    # Given a trained model object ``xgb_model``, it can be produced with:
    #
    #     xgb_json = xgb_model.get_dump(dump_format='json')
    #     with codecs.open('gender_xgb.json', 'w', encoding="utf-8") as f:
    #         for single_json in xgb_json:
    #             single_json = single_json.replace('\n', ' ').replace('\r', ' ')
    #             f.write(single_json + '\n')

    # Read with the same explicit encoding the dump is written with, rather
    # than relying on the platform default.
    with codecs.open('gender_xgb.json', 'r', encoding="utf-8") as f_read:
        xgb_json = f_read.readlines()

    tree_sqls = parse_xgb_trees(xgb_json)

    # One column expression per tree, each followed by a newline.
    with codecs.open('xgb_model.sql', 'w', encoding="utf-8") as f:
        for item_sql in tree_sqls:
            f.write(item_sql + '\n')

模型文件示例:

 { "nodeid": 0, "depth": 0, "split": "f2354", "split_condition": 1393.75, "yes": 1, "no": 2, "missing": 2, "children": [     { "nodeid": 1, "depth": 1, "split": "f4927", "split_condition": 0, "yes": 3, "no": 4, "missing": 3, "children": [       { "nodeid": 3, "depth": 2, "split": "f2790", "split_condition": 95, "yes": 7, "no": 8, "missing": 8, "children": [         { "nodeid": 7, "depth": 3, "split": "f2388", "split_condition": 611, "yes": 15, "no": 16, "missing": 16, "children": [           { "nodeid": 15, "leaf": -0.100344 },           { "nodeid": 16, "leaf": -0.0164536 }         ]},         { "nodeid": 8, "depth": 3, "split": "f302", "split_condition": 61, "yes": 17, "no": 18, "missing": 18, "children": [           { "nodeid": 17, "leaf": 0.0484211 },           { "nodeid": 18, "leaf": -0.130714 }         ]}       ]},       { "nodeid": 4, "depth": 2, "split": "f250", "split_condition": 3013, "yes": 9, "no": 10, "missing": 10, "children": [         { "nodeid": 9, "depth": 3, "split": "f1799", "split_condition": 5.5, "yes": 19, "no": 20, "missing": 19, "children": [           { "nodeid": 19, "leaf": 0.133858 },           { "nodeid": 20, "leaf": 0.00789474 }         ]},         { "nodeid": 10, "depth": 3, "split": "f2391", "split_condition": 0, "yes": 21, "no": 22, "missing": 21, "children": [           { "nodeid": 21, "leaf": 0.0243929 },           { "nodeid": 22, "leaf": -0.0782429 }         ]}       ]}     ]},     { "nodeid": 2, "depth": 1, "split": "f4534", "split_condition": 280.5, "yes": 5, "no": 6, "missing": 6, "children": [       { "nodeid": 5, "depth": 2, "split": "f2387", "split_condition": 720.5, "yes": 11, "no": 12, "missing": 12, "children": [         { "nodeid": 11, "depth": 3, "split": "f278", "split_condition": 57, "yes": 23, "no": 24, "missing": 24, "children": [           { "nodeid": 23, "leaf": 0.113749 },           { "nodeid": 24, "leaf": -0.064978 }         ]},         { "nodeid": 12, "depth": 3, "split": "f4796", "split_condition": 4499.5, 
"yes": 25, "no": 26, "missing": 26, "children": [           { "nodeid": 25, "leaf": 0.155554 },           { "nodeid": 26, "leaf": 0.0842534 }         ]}       ]},       { "nodeid": 6, "depth": 2, "split": "f5718", "split_condition": 263, "yes": 13, "no": 14, "missing": 14, "children": [         { "nodeid": 13, "depth": 3, "split": "f4793", "split_condition": 0, "yes": 27, "no": 28, "missing": 27, "children": [           { "nodeid": 27, "leaf": -0.114908 },           { "nodeid": 28, "leaf": -0.0101812 }         ]},         { "nodeid": 14, "depth": 3, "split": "f4798", "split_condition": 310.5, "yes": 29, "no": 30, "missing": 30, "children": [           { "nodeid": 29, "leaf": 0.131054 },           { "nodeid": 30, "leaf": 0.0196653 }         ]}       ]}     ]}   ]}

解析后hive sql语句

    case when (f2354 is null or f2354 >= 1393.75) then
        case when (f4534 is null or f4534 >= 280.5) then
            case when (f5718 is null or f5718 >= 263) then
                case when (f4798 is null or f4798 >= 310.5) then
                    0.0196653
                else
                    0.131054 
                end
            else
                case when (f4793 is null or f4793 < 0) then
                    -0.114908
                else
                    -0.0101812 
                end
            end
        else
            case when (f2387 is null or f2387 >= 720.5) then
                case when (f4796 is null or f4796 >= 4499.5) then
                    0.0842534
                else
                    0.155554 
                end
            else
                case when (f278 is null or f278 >= 57) then
                    -0.064978
                else
                    0.113749 
                end
            end
        end
    else
        case when (f4927 is null or f4927 < 0) then
            case when (f2790 is null or f2790 >= 95) then
                case when (f302 is null or f302 >= 61) then
                    -0.130714
                else
                    0.0484211 
                end
            else
                case when (f2388 is null or f2388 >= 611) then
                    -0.0164536
                else
                    -0.100344 
                end
            end
        else
            case when (f250 is null or f250 >= 3013) then
                case when (f2391 is null or f2391 < 0) then
                    0.0243929
                else
                    -0.0782429 
                end
            else
                case when (f1799 is null or f1799 < 5.5) then
                    0.133858
                else
                    0.00789474 
                end
            end
        end
    end as tree_0_score,

 

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值