通用预处理
import shap
# 创建Explainer(以树模型为例)
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X_test)
一、全局解释 (Global Interpretation)
1. 特征重要性条形图 (Summary Bar Plot)
import shap
import matplotlib.pyplot as plt
# 统一处理多分类问题
if isinstance(shap_values, list):
class_shap = shap_values[0] # 二分类选第一个类别
else:
class_shap = shap_values
# 绘制条形图
shap.summary_plot(class_shap, X_test, plot_type="bar", show=False)
plt.title("SHAP Feature Importance (Bar Plot)")
plt.tight_layout()
plt.show()
2. 特征重要性散点图 (Summary Dot Plot)
shap.summary_plot(class_shap, X_test, plot_type="dot", show=False)
plt.title("SHAP Feature Importance (Dot Plot)")
plt.tight_layout()
plt.show()
3. 交互热力图 (Heatmap Plot)
shap.plots.heatmap(class_shap, show=False)
plt.title("SHAP Value Heatmap")
plt.tight_layout()
plt.show()
二、特征效应分析 (Feature Effect)
4. 特征依赖图 (Dependence Plot)
feature_idx = 0 # 修改为目标特征索引
shap.dependence_plot(feature_idx,
class_shap,
X_test,
interaction_index="auto", # 自动检测交互特征
show=False)
plt.title(f"Dependence Plot - {X_test.columns[feature_idx]}")
plt.tight_layout()
plt.show()
5. 交互作用图 (Interaction Plot)
# 计算交互值(需模型支持)
shap_interaction = explainer.shap_interaction_values(X_test)
shap.summary_plot(shap_interaction[0], X_test, show=False)
plt.title("Interaction Values Summary")
plt.tight_layout()
plt.show()
三、局部解释 (Local Interpretation)
6. 单个样本解释-瀑布图 (Waterfall Plot)
sample_idx = 0 # 样本索引
shap.plots.waterfall(class_shap[sample_idx],
max_display=15,
show=False)
plt.title(f"Waterfall Plot - Sample {sample_idx}")
plt.tight_layout()
plt.show()
7. 单个样本解释-力图 (Force Plot)
shap.plots.force(explainer.expected_value[0],
class_shap[sample_idx],
X_test.iloc[sample_idx],
matplotlib=True, # 用Matplotlib渲染
show=False)
plt.title(f"Force Plot - Sample {sample_idx}")
plt.tight_layout()
plt.show()
四、多样本分析 (Multi-Sample)
8. 蜂群图 (Beeswarm Plot)
shap.plots.beeswarm(class_shap, show=False)
plt.title("Beeswarm Plot")
plt.tight_layout()
plt.show()
9. 决策路径图 (Decision Plot)
shap.decision_plot(explainer.expected_value[0],
class_shap[:50], # 显示前50样本
feature_names=X_test.columns,
show=False)
plt.title("Decision Path Visualization")
plt.tight_layout()
plt.show()
五、高级分析
10. 特征聚类分析 (Clustering Plot)
shap.plots.bar(class_shap.cohorts(2).abs.mean(0), # 分2个聚类
show=False)
plt.title("Feature Importance by Clusters")
plt.tight_layout()
plt.show()