【MATLAB第69期】#源码分享 | 基于MATLAB的SHAP (SHapley Additive exPlanations)解释模型预测方法的多种实现方式合集

【MATLAB第69期】#源码分享 | 基于MATLAB的SHAP (SHapley Additive exPlanations)解释模型预测方法的多种实现方式合集


shapley函数需要在 matlab R2021a版本以上运行
R2023a版本:支持线性SHAP和树形SHAP算法


一、shapley创建对象时计算 Shapley 值

训练分类模型并创建shapley对象。创建shapley对象时,指定查询点,以便软件计算查询点的 Shapley 值。然后使用对象函数 创建 Shapley 值的条形图plot。

加载CreditRating_Historical数据集。该数据集包含客户 ID 及其财务比率、行业标签和信用评级。

tbl = readtable('CreditRating_Historical.dat');
head(tbl,3)
blackbox = fitcecoc(tbl,'Rating', ...
    'PredictorNames',tbl.Properties.VariableNames(2:7), ...
    'CategoricalPredictors','Industry', ...
    'ClassNames',{'AAA' 'AA' 'A' 'BBB' 'BB' 'B' 'CCC'});


使用该函数训练信用评级黑盒模型fitcecoc。使用第二列到第七列的变量作为tbl预测变量。推荐的做法是指定类名称来设置类的顺序。

queryPoint = tbl(end,:)
explainer = shapley(blackbox,'QueryPoint',queryPoint)

创建一个shapley对象来解释最后一次观察的预测。指定查询点,以便软件计算 Shapley 值并将其存储在属性中ShapleyValues。
计算可能会很慢,因为预测数据有超过 1000 个观测值。使用训练集的较小样本或将“UseParallel”指定为 true 以加快计算速度。

绘图

plot(explainer)

在这里插入图片描述
水平条形图显示所有变量的 Shapley 值,按绝对值排序。每个 Shapley 值都解释了由于相应的变量而导致的查询点分数与预测类平均分数的偏差。

二、使用以下命令创建shapley对象并计算 Shapley 值fit

训练回归模型并创建一个shapley对象。创建shapley对象时,如果未指定查询点,则软件不会计算 Shapley 值。使用对象函数fit计算指定查询点的 Shapley 值。然后使用对象函数 创建 Shapley 值的条形图plot。

加载carbig数据集,其中包含 20 世纪 70 年代和 80 年代初制造的汽车的测量数据。

load carbig
tbl = table(Acceleration,Cylinders,Displacement,Horsepower,Model_Year,Weight,MPG);%创建一个包含预测变量Acceleration、Cylinders等以及响应变量 的表MPG。
tbl = rmmissing(tbl);%删除训练集中的缺失值有助于减少内存消耗并加快函数的训练速度fitrkernel。删除 中的缺失值tbl。
rng('default') % For reproducibility
mdl = fitrkernel(tbl,'MPG','CategoricalPredictors',[2 5]);%MPG使用该fitrkernel函数训练黑盒模型
explainer = shapley(mdl,tbl)%创建一个shapley对象。指定数据集tbl,因为mdl不包含训练数据。explainer将训练数据存储tbl在X属性中。计算 中第一个观测值的所有预测变量的 Shapley 值tbl。
queryPoint = tbl(1,:)
explainer = fit(explainer,queryPoint);
explainer.ShapleyValues%对于回归模型,shapley使用预测响应计算 Shapley 值,并将其存储在ShapleyValues属性中。显示属性中的值ShapleyValues。
plot(explainer)%使用该plot函数绘制查询点的 Shapley 值。

在这里插入图片描述
水平条形图显示所有变量的 Shapley 值,按绝对值排序。每个 Shapley 值都解释了由于相应变量而导致的查询点预测与平均值的偏差。

三、使用函数句柄指定黑盒模型

训练回归模型并使用模型函数的shapley函数句柄创建对象。predict使用对象函数fit计算指定查询点的 Shapley 值。然后使用对象函数 绘制 Shapley 值plot。

加载carbig数据集,其中包含 20 世纪 70 年代和 80 年代初制造的汽车的测量数据。

load carbig
tbl = table(Acceleration,Cylinders,Displacement,Horsepower,Model_Year,Weight);
rng('default') % For reproducibility
Mdl = TreeBagger(100,tbl,MPG,'Method','regression','CategoricalPredictors',[2 5]);%MPG使用该函数训练黑盒模型TreeBagger。

f = @(tbl) predict(Mdl,tbl,'Trees',1:50);%shapley不直接支持TreeBagger对象,因此您无法将 的第一个输入参数(黑盒模型)指定为shapley对象TreeBagger。相反,您可以使用该函数的函数句柄predict。您还可以使用函数的名称-值参数指定函数的选项predict。创建对象predict函数的函数句柄。指定要用作 的树索引数组。
explainer = shapley(f,tbl,'CategoricalPredictors',[2 5]);%shapley使用函数句柄创建一个对象f。当您指定黑盒模型作为函数句柄时,您必须提供预测变量数据。tbl包括具有数据类型的分类预测变量 (Cylinder和Model_Year) double。默认情况下,shapley不将具有该double数据类型的变量视为分类预测变量。将第二个 ( Cylinder) 和第五个 ( Model_Year) 变量指定为分类预测变量。
explainer = fit(explainer,tbl(1,:));

plot(explainer)

在这里插入图片描述

四、随机森林RF模型

tbl = readtable('CreditRating_Historical.dat');
%Display the first three rows of the table.

head(tbl,3)
%Create a table of predictor variables by removing the columns containing customer IDs and ratings from tbl.

tblX = removevars(tbl,["ID","Rating"]);
%Train an ensemble of bagged decision trees by using the fitcensemble function and specifying the ensemble aggregation method as random forest ('Bag'). For reproducibility of the random forest algorithm, specify the 'Reproducible' name-value argument as true for tree learners. Also, specify the class names to set the order of the classes in the trained model.

rng('default') % For reproducibility
t = templateTree('Reproducible',true);
blackbox = fitcensemble(tblX,tbl.Rating, ...
    'Method','Bag','Learners',t, ...
    'CategoricalPredictors','Industry', ...
    'ClassNames',{'AAA' 'AA' 'A' 'BBB' 'BB' 'B' 'CCC'});
%blackbox is a ClassificationBaggedEnsemble model.

%Use Model-Specific Interpretability Features

%ClassificationBaggedEnsemble supports two object functions, oobPermutedPredictorImportance and predictorImportance, which find important predictors in the trained model.

%Estimate out-of-bag predictor importance by using the oobPermutedPredictorImportance function. The function randomly permutes out-of-bag data across one predictor at a time, and estimates the increase in the out-of-bag error due to this permutation. The larger the increase, the more important the feature.

Imp1 = oobPermutedPredictorImportance(blackbox);
%Estimate predictor importance by using the predictorImportance function. The function estimates predictor importance by summing changes in the node risk due to splits on each predictor and dividing the sum by the number of branch nodes.

Imp2 = predictorImportance(blackbox);
%Create a table containing the predictor importance estimates, and use the table to create horizontal bar graphs. To display an existing underscore in any predictor name, change the TickLabelInterpreter value of the axes to 'none'.

table_Imp = table(Imp1',Imp2', ...
    'VariableNames',{'Out-of-Bag Permuted Predictor Importance','Predictor Importance'}, ...
    'RowNames',blackbox.PredictorNames);
tiledlayout(1,2)
ax1 = nexttile;
table_Imp1 = sortrows(table_Imp,'Out-of-Bag Permuted Predictor Importance');
barh(categorical(table_Imp1.Row,table_Imp1.Row),table_Imp1.('Out-of-Bag Permuted Predictor Importance'))
xlabel('Out-of-Bag Permuted Predictor Importance')
ylabel('Predictor')
ax2 = nexttile;
table_Imp2 = sortrows(table_Imp,'Predictor Importance');
barh(categorical(table_Imp2.Row,table_Imp2.Row),table_Imp2.('Predictor Importance'))
xlabel('Predictor Importance')
ax1.TickLabelInterpreter = 'none';
ax2.TickLabelInterpreter = 'none';


%Both object functions identify MVE_BVTD and RE_TA as the two most important predictors.

%Specify Query Point

%Find the observations whose Rating is 'AAA' and choose four query points among them.

rng('default')
tblX_AAA = tblX(strcmp(tbl.Rating,'AAA'),:);
queryPoint = datasample(tblX_AAA,4,'Replace',false)

%Use LIME with Linear Simple Models

%Explain the predictions for the query points using lime with linear simple models. lime generates a synthetic data set and fits a simple model to the synthetic data set.

%Create a lime object using tblX_AAA so that lime generates a synthetic data set using only the observations whose Rating is 'AAA', not the entire data set.

explainer_lime = lime(blackbox,tblX_AAA);
%The default value of DataLocality for lime is 'global', which implies that, by default, lime generates a global synthetic data set and uses it for any query points. lime uses different observation weights so that weight values are more focused on the observations near the query point. Therefore, you can interpret each simple model as an approximation of the trained model for a specific query point.

%Fit simple models for the four query points by using the object function fit. Specify the third input (the number of important predictors to use in the simple model) as 6 to use all six predictors.

explainer_lime1 = fit(explainer_lime,queryPoint(1,:),6);
explainer_lime2 = fit(explainer_lime,queryPoint(2,:),6);
explainer_lime3 = fit(explainer_lime,queryPoint(3,:),6);
explainer_lime4 = fit(explainer_lime,queryPoint(4,:),6);
%Plot the coefficients of the simple models by using the object function plot.

tiledlayout(2,2)
ax1 = nexttile; plot(explainer_lime1);
ax2 = nexttile; plot(explainer_lime2);
ax3 = nexttile; plot(explainer_lime3);
ax4 = nexttile; plot(explainer_lime4);
ax1.TickLabelInterpreter = 'none';
ax2.TickLabelInterpreter = 'none';
ax3.TickLabelInterpreter = 'none';
ax4.TickLabelInterpreter = 'none';


%All simple models identify EBIT_TA, MVE_BVTD, RE_TA, and WC_TA as the four most important predictors. The positive coefficients for the predictors suggest that increasing the predictor values leads to an increase in the predicted scores in the simple models.

%For a categorical predictor, the plot function displays only the most important dummy variable of the categorical predictor. Therefore, each bar graph displays a different dummy variable.

%Compute Shapley Values

%The Shapley value of a predictor for a query point explains the deviation of the predicted score for the query point from the average score, due to the predictor. Create a shapley object using tblX_AAA so that shapley computes the expected contribution based on the samples for 'AAA'.

explainer_shapley = shapley(blackbox,tblX_AAA);
%Compute the Shapley values for the query points by using the object function fit.

explainer_shapley1 = fit(explainer_shapley,queryPoint(1,:));
explainer_shapley2 = fit(explainer_shapley,queryPoint(2,:));
explainer_shapley3 = fit(explainer_shapley,queryPoint(3,:));
explainer_shapley4 = fit(explainer_shapley,queryPoint(4,:));
%Plot the Shapley values by using the object function plot.

tiledlayout(2,2)
nexttile
plot(explainer_shapley1)
nexttile
plot(explainer_shapley2)
nexttile
plot(explainer_shapley3)
nexttile
plot(explainer_shapley4)

在这里插入图片描述

五、高斯回归模型

%Train GPR Model

%Load the carbig data set, which contains measurements of cars made in the 1970s and early 1980s.

load carbig
%Create a table containing the predictor variables Acceleration, Cylinders, and so on

tbl = table(Acceleration,Cylinders,Displacement,Horsepower,Model_Year,Weight);
%Train a GPR model of the response variable MPG by using the fitrgp function. Specify KernelFunction as 'ardsquaredexponential' to use the squared exponential kernel with a separate length scale per predictor.

blackbox = fitrgp(tbl,MPG,'ResponseName','MPG','CategoricalPredictors',[2 5], ...
    'KernelFunction','ardsquaredexponential');
%blackbox is a RegressionGP model.

%Use Model-Specific Interpretability Features

%You can compute predictor weights (predictor importance) from the learned length scales of the kernel function used in the model. The length scales define how far apart a predictor can be for the response values to become uncorrelated. Find the normalized predictor weights by taking the exponential of the negative learned length scales.

sigmaL = blackbox.KernelInformation.KernelParameters(1:end-1); % Learned length scales
weights = exp(-sigmaL); % Predictor weights
weights = weights/sum(weights); % Normalized predictor weights
%Create a table containing the normalized predictor weights, and use the table to create horizontal bar graphs. To display an existing underscore in any predictor name, change the TickLabelInterpreter value of the axes to 'none'.

tbl_weight = table(weights,'VariableNames',{'Predictor Weight'}, ...
    'RowNames',blackbox.ExpandedPredictorNames);
tbl_weight = sortrows(tbl_weight,'Predictor Weight');
b = barh(categorical(tbl_weight.Row,tbl_weight.Row),tbl_weight.('Predictor Weight'));
b.Parent.TickLabelInterpreter = 'none'; 
xlabel('Predictor Weight')
ylabel('Predictor')


%The predictor weights indicate that multiple dummy variables for the categorical predictors Model_Year and Cylinders are important.

%Specify Query Point

%Find the observations whose MPG values are smaller than the 0.25 quantile of MPG. From the subset, choose four query points that do not include missing values.

rng('default') % For reproducibility
idx_subset = find(MPG < quantile(MPG,0.25));
tbl_subset = tbl(idx_subset,:);
queryPoint = datasample(rmmissing(tbl_subset),4,'Replace',false)
   

%Use LIME with Tree Simple Models

%Explain the predictions for the query points using lime with decision tree simple models. lime generates a synthetic data set and fits a simple model to the synthetic data set.

%Create a lime object using tbl_subset so that lime generates a synthetic data set using the subset instead of the entire data set. Specify SimpleModelType as 'tree' to use a decision tree simple model.

explainer_lime = lime(blackbox,tbl_subset,'SimpleModelType','tree');
%The default value of DataLocality for lime is 'global', which implies that, by default, lime generates a global synthetic data set and uses it for any query points. lime uses different observation weights so that weight values are more focused on the observations near the query point. Therefore, you can interpret each simple model as an approximation of the trained model for a specific query point.

%Fit simple models for the four query points by using the object function fit. Specify the third input (the number of important predictors to use in the simple model) as 6. With this setting, the software specifies the maximum number of decision splits (or branch nodes) as 6 so that the fitted decision tree uses at most all predictors.

explainer_lime1 = fit(explainer_lime,queryPoint(1,:),6);
explainer_lime2 = fit(explainer_lime,queryPoint(2,:),6);
explainer_lime3 = fit(explainer_lime,queryPoint(3,:),6);
explainer_lime4 = fit(explainer_lime,queryPoint(4,:),6);
%Plot the predictor importance by using the object function plot.

tiledlayout(2,2)
ax1 = nexttile; plot(explainer_lime1);
ax2 = nexttile; plot(explainer_lime2);
ax3 = nexttile; plot(explainer_lime3);
ax4 = nexttile; plot(explainer_lime4);
ax1.TickLabelInterpreter = 'none';
ax2.TickLabelInterpreter = 'none';
ax3.TickLabelInterpreter = 'none';
ax4.TickLabelInterpreter = 'none';


%All simple models identify Displacement, Model_Year, and Weight as important predictors.

%Compute Shapley Values

%The Shapley value of a predictor for a query point explains the deviation of the predicted response for the query point from the average response, due to the predictor. Create a shapley object for the model blackbox using tbl_subset so that shapley computes the expected contribution based on the observations in tbl_subset.

explainer_shapley = shapley(blackbox,tbl_subset);
%Compute the Shapley values for the query points by using the object function fit.

explainer_shapley1 = fit(explainer_shapley,queryPoint(1,:));
explainer_shapley2 = fit(explainer_shapley,queryPoint(2,:));
explainer_shapley3 = fit(explainer_shapley,queryPoint(3,:));
explainer_shapley4 = fit(explainer_shapley,queryPoint(4,:));
%Plot the Shapley values by using the object function plot.

tiledlayout(2,2)
nexttile
plot(explainer_shapley1)
nexttile
plot(explainer_shapley2)
nexttile
plot(explainer_shapley3)
nexttile
plot(explainer_shapley4)

在这里插入图片描述

  • 6
    点赞
  • 27
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

随风飘摇的土木狗

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值