5.1 可视化
人类天生对图形的获取信息能力很强,通过图形化的展示来表达模型过程和结论可以有效的帮助分析人员进行判断,xgboost内置了树的画图函数以及变量重要性的画图函数,展示树我们可以直观的看到树的分裂过程,哪些变量以什么样的方式在影响结果,展示变量重要性可以从全局角度了解哪些变量更加重要。以下分别为树和变量重要性的画图代码:
def visual_tree(model, num_tree):
xgb.plot_tree(model, num_trees=num_tree)
def visual_var_importance(model):
xgb.plot_importance(model, importance_type='weight')
def main():
data_path = './data/boston/boston.csv'
train_x, train_y, valid_x, valid_y = get_train_valid_boston(data_path)
model = regression(train_x, train_y, valid_x, valid_y)
visual_tree(model, 0)
visual_var_importance(model)
plt.show()
因为模型有很多棵树,我们挑一个进行展示,这里选的是第一棵树,运行结果如下。第一幅图是树的可视化,包含生个分裂点的属性,属性分裂阈值,最终叶节点的值,可以看到一个属性在树不同深度是可以多次出现的。第二幅图是变量重要性的可视化,左侧为变量名,右侧为重要性值,这里默认为weight值,实际上重要性有三种表示方法,分别是weight、gain、cover,其中weight代表每个属性在这个森林里出现的次数,每个节点算一次;gain代表每个属性在每次分裂时的信息增益的平均值;cover代表每个属性在每次分裂时所对应的样本数的平均值。
5.2 保存模型与加载
当模型比较成熟以后,需要落地,交付生产,这个时候就需要把训练好的模型保存下来,应用模型时,从本地加载,然后可以调用接口进行预测,这一系列过程用起来非常简单。
def save_model(model, model_path):
model.save_model(model_path)
def load_model(model_path):
model = XGBRegressor()
model.load_model(model_path)
return model
# main 函数
def main():
data_path = './data/boston/boston.csv'
train_x, train_y, valid_x, valid_y = get_train_valid_boston(data_path)
model = regression(train_x, train_y, valid_x, valid_y)
model_path = './data/model/xgb_reg_model_201912121124.bin'
save_model(model, model_path)
model_load = load_model(model_path)
pred_y = model_load.predict(valid_x)
print(pred_y)
5.3 森林信息
前面我们通过画图展示过一棵树的结构,有些人可能会觉着一棵树表达的信息不完整,是否可以将森林里每一棵树都打印出来,这样可以才可以比较放心?答案是肯定的,有没有发现xgboost类库几乎替你考虑到了所有你想到的,就像python的语法总有一处处让你惊喜的地方。还是先上代码:
def save_forest_info(model, model_path):
model.dump_model(model_path)
def main():
data_path = './data/boston/boston.csv'
train_x, train_y, valid_x, valid_y = get_train_valid_boston(data_path)
model = regression(train_x, train_y, valid_x, valid_y)
model_path = './data/model/xgb_reg_model_201912121124.txt'
save_forest_info(model, model_path)
可以发现,代码很简单,只有一句话,参数为模型保存路径,我们仍然用前面的回归模型,运行结果文件为所有的树结构信息,除了分裂节点编号,分裂条件,叶节点的值以外,还有每个叶节点的样本分布情况。拷贝出第一棵树和最后一棵树,如下所示。
booster[0]:
0:[RM<6.85300016] yes=1,no=2,missing=1
1:[LSTAT<14.3950005] yes=3,no=4,missing=3
3:[LSTAT<5.51000023] yes=7,no=8,missing=7
7:leaf=2.87136364
8:leaf=2.16296291
4:[CRIM<6.9257946] yes=9,no=10,missing=9
9:leaf=1.64333332
10:leaf=1.13913798
2:[RM<7.44499969] yes=5,no=6,missing=5
5:[CRIM<12.6701298] yes=11,no=12,missing=11
11:leaf=3.07564116
12:leaf=0.813333333
6:[CRIM<2.74223518] yes=13,no=14,missing=13
13:leaf=4.33173895
14:leaf=1.07000005
......
booster[151]:
0:[AGE<8.10000038] yes=1,no=2,missing=1
1:[LSTAT<5.35500002] yes=3,no=4,missing=3
3:leaf=-0.0826229081
4:[CRIM<0.0899500027] yes=7,no=8,missing=7
7:leaf=0.0183633808
8:leaf=-0.016655311
2:[AGE<36.1500015] yes=5,no=6,missing=5
5:[RM<6.65649986] yes=9,no=10,missing=9
9:leaf=0.00017928562
10:leaf=0.0459366888
6:[RM<6.73850012] yes=11,no=12,missing=11
11:leaf=0.00154620328
12:leaf=-0.0204947647