参考pdpbox官方文档中的其他类,绘制相应的图,任选即可
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from pdpbox import pdp_interact, pdp_interact_plot
data = pd.read_csv('titanic.csv')
X = data[['Age', 'Fare', 'Pclass', 'Sex']]
y = data['Survived']
# 简单训练一个分类模型
model = RandomForestClassifier()
model.fit(X, y)
interact_features = ['Age', 'Fare']
# 实例化pdp_interact类
interact_out = pdp_interact(
model=model,
dataset=X,
model_features=X.columns.tolist(),
features=interact_features,
num_grid_points=10 # 每个特征网格点数
)
# 生成热力图(展示二维交互影响)
pdp_plot = pdp_interact_plot(
pdp_interact_out=interact_out,
feature_names=interact_features,
plot_type='grid', # 可选'contour'等高线
plot_pts_dist=True # 显示数据点分布
)
# 自定义样式(如标题、颜色)
pdp_plot['pdp_interact_plot'].figure.suptitle('Interaction between Age and Fare', y=1.05)
pdp_plot['pdp_interact_plot'].set_colormap('viridis') # 颜色映射
506

被折叠的 条评论
为什么被折叠?



