captum在官网上给出了文本分类的例子(Captum · Model Interpretability for PyTorch),这是一个二分类模型,但对于多分类模型,却没有给出例子来,经过自己的实践,完成了多标签文本分类的可解释模型。
其实主要就是lig.attribute这个函数,关于函数的说明,详见:Lib\site-packages\captum\attr\_core\layer\layer_integrated_gradients.py这个文件中的attribute函数,函数的参数里面都有说明,inputs是一个tensor格式数据或tuple格式数据,这和模型forword中的输入要一致,一般我们直接传入token_ids,有可能还会有token_mask等,我们可以借助additional_forward_args这个参数,如果使用additional_forward_args,模型中forword也需要有相关的参数。若不使用additional_forward_args这个参数,则inputs需要传入tuple。
lig.attribute((input_ids,decoder_input_ids), baselines=(reference_indices,dec_reference_indices),.....)
对于多分类,一定要指定target这个参数,这个参数要和预测到的类别一致,比如预测到的类别为5,那么target=5。
另外一点就是模型的输出了,要是用captum解释多分类,模型的输出是要改变的,比如有10个类别,模型的输出为 1*10。表示10个类别是否存在的概率,但lig.attribute函数中有个参数为n_steps,默认为50,所有模型的输出也要为 n_steps*10。简单的改写为:
output = output.reshape(-1,num_class)
基本就这些了......