多标签文本分类中使用captum

captum在官网上给出了文本分类的例子(Captum · Model Interpretability for PyTorch),这是一个二分类模型,但对于多分类模型,却没有给出例子来,经过自己的实践,完成了多标签文本分类的可解释模型。其实主要就是lig.attribute这个函数,关于函数的说明,详见:Lib\site-packages\captum\attr\_core\layer\layer_integrated_gradients.py这个文件中的attribute函数,函数的参数里面都有说明
摘要由CSDN通过智能技术生成

captum在官网上给出了文本分类的例子(Captum · Model Interpretability for PyTorch),这是一个二分类模型,但对于多分类模型,却没有给出例子来,经过自己的实践,完成了多标签文本分类的可解释模型。

其实主要就是lig.attribute这个函数,关于函数的说明,详见:Lib\site-packages\captum\attr\_core\layer\layer_integrated_gradients.py这个文件中的attribute函数,函数的参数里面都有说明,inputs是一个tensor格式数据或tuple格式数据,这和模型forword中的输入要一致,一般我们直接传入token_ids,有可能还会有token_mask等,我们可以借助additional_forward_args这个参数,如果使用additional_forward_args,模型中forword也需要有相关的参数。若不使用additional_forward_args这个参数,则inputs需要传入tuple。

lig.attribute((input_ids,decoder_input_ids), baselines=(reference_indices,dec_reference_indices),.....)

对于多分类,一定要指定target这个参数,这个参数要和预测到的类别一致,比如预测到的类别为5,那么target=5。

另外一点就是模型的输出了,要是用captum解释多分类,模型的输出是要改变的,比如有10个类别,模型的输出为 1*10。表示10个类别是否存在的概率,但lig.attribute函数中有个参数为n_steps,默认为50,所有模型的输出也要为 n_steps*10。简单的改写为:

output = output.reshape(-1,num_class)

基本就这些了......

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值