三分类SHAP图(特征标准化之后怎么画)

画三分类SHAP图出错

今天干了一件很蠢的事情,还耽误了很多时间,特此记录一下
我将数据标准化之后训练模型,然后将未标准化的数据作为输入计算了SHAP值,得出的结果显然不对。类似于下图这种
在这里插入图片描述
但是如果画图时将X_test输入作为参数,那么横坐标就对应的是标准化之后的值,所以我们可以先对X_test未经标准化时候制作一个copy版本X_test1,然后作为画图时候参数输入就可以正确画出SHAP图的横坐标了,也可以得到我们想要的信息。另外三分类shap values得到一个3维数据,有时候使用起来需要切片,比如画单个特征的shap图,但是画总体概览图时候不用。

import shap
X_test = pd.DataFrame(X_test,columns=x_test_cols)
explainer = shap.TreeExplainer(lgb_model)
shap_values = explainer.shap_values(X_test)  # 传入特征矩阵X,计算SHAP值
plt.figure()
#plt.rcParams['figure.dpi'] = 300 #分辨率
plt.title('LightGBM model SHAP values')
shap.summary_plot(shap_values, X_test,show=False)
plt.savefig(save_path+'\shap'+'lgb.png',dpi=300,bbox_inches = 'tight')

shap.initjs()
shap.dependence_plot('Na1', shap_values[1], X_test,interaction_index=None,show=False) #注意:如皋这么画,那么SHAP横坐标就是标准化之后的值
plt.axhline(y=0, color="red",linestyle='-')

#shap.force_plot(explainer.expected_value[0], shap_values[0][0,:], X_test.iloc[0,:])
#shap.force_plot(explainer.expected_value[0], shap_values[0], X_test)
#shap.dependence_plot("Na1", shap_values[1], X_test)

import os
shap_path = save_path +r'\class1'
if not os.path.isdir(shap_path):
    os.makedirs(shap_path)
for i in X_test.columns.values.tolist():
    plt.figure()
    shap.dependence_plot(i, shap_values[1], X_test1,interaction_index=None,show=False)
    plt.axhline(y=0, color="red",linestyle='-') #X_test1是X_test的一个未经标准化的复制版本
    plt.savefig(shap_path+ "\shap"+str(i)+'.png',dpi=300,bbox_inches = 'tight')

下面是一张正确的结果图
在这里插入图片描述

  • 4
    点赞
  • 20
    收藏
    觉得还不错? 一键收藏
  • 7
    评论
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值