2020/11/30
SHAP分析Mnist数据集的Mask遮掩复现实验
实验介绍
论文:A Unified Approach to Interpreting ModelPredictions
论文地址:https://papers.nips.cc/paper/2017/file/8a20a8621978632d76c43dfd28b67767-Paper.pdf
在SHAP论文中的Experiments部分,提出了一个可解释模型的评估指标:比较不同算法找出的最重要的特征的实际影响
如下图所示,对20张图片的中shap_values最高的前几十个像素点进行了遮掩(替换为其他像素)后,与原来的模型分析结果进行对比,以此来说明SHAP分析模型找出的最重要的特征的实际影响是最大的。
实验步骤
实验大致分为三个部分:
1.本地复现Shap分析Mnist
下载mnist数据集,掌握如何用训练好的model对本地的image和label进行测试,然后对Shap_Mnist的代码进行修改,在用本地dataset完成基本的实验复现,能实现本地Test_loader的建立就行:
2.自己挑选数据测试训练好的model并提取shap_values参数
挑选出8组(随意)测试数据,用Mnist的大量train数据集进行model的训练,然后在原来Shap_Mnist的可视化部分代码进行修改,提取出shap_values的值:
3.根据shap_values参数进行遮掩实验并进行结果可视化
shap_values的维度是10*8*1*28*28,分别对应着0-9的十个数字、一个batch的图片数、通道数、图片像素点的行、图片像素点的列。
因此对于后面的28*28两维度对应一张图片的一个结果,将其压缩为784的tensor,然后用numpy.argsort从大到小进行排序返回index,挑选出前20个最大shap_values对应的index,并用index\28,index%28计算出像素点的坐标。
将test_images的对应坐标遮掩为128,然后重新进行shap.DeepExplainer(model, background)获取shap_values并进行可视化。
实验结果
20images、30images、20+30images在遮掩了100个点的实验结果(更多实验结果看文件):
实验总结
说实话,因为刚入门,python都写不流畅,简简单单的几个步骤,每一个步骤都试错了很久。特别是如何用自己的image和label创建dataset实属一直跳坑里,最后终于也是在同学的帮助下成功地完成了修正。本次实验就到可视化部分了,具体的分析留着在后面继续完成,后续还有LIME的遮掩实验复现以追求二者的区分。作为一个实验的复现,实验的结果基本上是板上钉钉————SHAP的结果要优于LIME,但是老师推荐我尝试,我也明白它的意义————能增强我对SHAP的理解以及对SHAP和LIME的区别有一个更深的认识,还有就是很大程度锻炼了我的一个测试能力,毕竟什么都只会看和自己亲历亲为完成实验还是有着本质上的区别,虽然仅仅是修修改改一点代码,但是我始终相信量变才能带来质变,厚积薄发,加油。
代码:https://download.csdn.net/download/ylwhxht/13587843
图片:https://download.csdn.net/download/ylwhxht/13587850