原始的shap一般是直接show出特征,
需求是保存多张图,做特征变化的对比
直接改shap.summary_plot源码可以实现
函数参数增加save=False,path=False
在summary_plot函数最下面增加
if save:
pl.savefig(path)
pl.close()
这里必须要close掉图层,要不然会出现多层叠加的问题
直接使用代码
explainer = shap.TreeExplainer(model)#模型训练用什么矩阵形状,这里要对应
shap_values = explainer.shap_values(X_train) # 传入特征矩阵X,计算SHAP值
shap.summary_plot(shap_values, X_train, plot_type="bar",max_display=50,show=False,\
save=True,path='./fac_importance/%s.png'%(d))
shap.summary_plot(shap_values, X_train,max_display=50,show=False,\
save=True,path='./shap/%s.png'%(d))