导语
上一节介绍了常用的bar,force,partial_dependence函数,本节介绍几种常用的plots函数
内容一: 绘制Beeswarm图
内容二: 通过decision函数解释模型是如何产生预测
内容三:通过Waterfall 每个特征是如何一步步累加形成模型的预测
scikit-learn 0.23.2, shap 0.42.1
该教程只能保证上述版本能够正常运行
Beeswarm plot
Beeswarm图显示数据集中热门要素如何影响模型输出的信息密集型摘要。给定解释的每个实例由每个特征图上的单个点表示。 它的位置揭示了特征对模型预测的贡献。通过Beeswarm图,我们可以迅速识别哪些特征对模型的决策影响最大。
第一种实现方法
import shap
import xgboost as xgb
import matplotlib.pyplot as pl
x,y= shap.datasets.adult()
seed=12345
xgb_model=xgb.XGBRegressor(random_state=seed).fit(x,y)
explainer = shap.Explainer(model, x)
shap_values = explainer(x)
fig= pl.gcf()
ax= pl.gca()
shap.summary_plot(shap_values, x,show=False)
pl.gca().spines['right'].set_visible(True)
pl.gca().spines['left'].set_visible(True)
pl.gca().spines['top'].set_visible(True)
pl.savefig('beewarm-1.png',bbox_inches='tight',dpi=400)
第二种实现方法
shap.plots.beeswarm(shap_values,show=False, max_display=12)
Decision plot
Decision图。它以树状图的形式展示了模型的决策过程。每个节点代表一个特征,每个分支代表一个特征的取值。通过Decision图,我们可以清晰地看到模型是如何一步步从根节点走到叶节点的,这有助于我们理解模型的决策逻辑。
再次吐槽开发者,在document中 API reference和API example中两个函数居然不相同
import shap
import xgboost as xgb
import matplotlib.pyplot as pl
x,y= shap.datasets.adult()
seed=12345
xgb_model=xgb.XGBRegressor(random_state=seed).fit(x,y)
explainer = shap.Explainer(model, x)
shap_values = explainer(x)
for a in range(shap_values.data.shape[0]):
shap.plots.decision(shap_values[a].base_values,shap_values[a].values,show=False,color_bar=False)
pl.gca().spines['right'].set_visible(True)
pl.gca().spines['left'].set_visible(True)
pl.gca().spines['top'].set_visible(True)
pl.savefig('decision_plot.png',bbox_inches='tight',dpi=400)
图中展示了每一个样本的示例,你也可以只绘制你需要的decision
Waterfall plot
最后,我们来介绍一下Waterfall图。它以瀑布流的形式展示了特征对模型预测的累积贡献。通过Waterfall图,我们可以直观地看到每个特征是如何一步步累加,最终形成模型的预测结果的。这对于理解模型的非线性关系特别有帮助。
import shap
import xgboost as xgb
import matplotlib.pyplot as pl
x,y= shap.datasets.adult()
seed=12345
xgb_model=xgb.XGBRegressor(random_state=seed).fit(x,y)
explainer = shap.Explainer(model, x)
shap_values = explainer(x)
shap.plots.waterfall(shap_values[20],show=False)
pl.gca().spines['right'].set_visible(True)
pl.gca().spines['left'].set_visible(True)
pl.gca().spines['top'].set_visible(True)
pl.savefig('decision_plot.png',bbox_inches='tight',dpi=400)
一行代码直出,如果你想生成每个样本的展示,或许可以这样修改
for a in range(shap_values.data.shape[0]):
shap.plots.waterfall(shap_values[20],show=False)
pl.gca().spines['right'].set_visible(True)
pl.gca().spines['left'].set_visible(True)
pl.gca().spines['top'].set_visible(True)
pl.savefig('decision_plot-%d.png'%a,bbox_inches='tight',dpi=400)
总结:
未完待续~