基于ModelScope获取[MASK]位置的top k个候选token

起因是我的任务需求涉及到使用PoNet这个模型来进行完形填空任务,需要根据logits值获取填入[MASK]位置的词汇表中的top k个候选token

我本身只用过huggingface,按huggingface的方式掩码语言模型一般都会提供有xxxxForMaskedLM这样一个接口,我可以方便的获取词汇表中每个token被填入的logit值

但是PoNet这个模型比较特殊,他在huggingface上没有提供xxxxForMaskedLM,只提供了PoNetForSequenceClassification接口

于是我来到ModelScope,ModelScope虽然有一个fill-mask的pipeline,但是这个pipeline封装得太好了,这个pipeline返回的就是[MASK]位置被填入logit值最高的token后的字符串,我无法直接获取到每个token的logit,我看官方文档也没有提到有没有类似xxxxForMaskedLM的接口,无奈只能自己摸索。

几个小时后终于给我摸索出来了

其实ModelScope也是有类似xxxxForMaskedLM的接口的,只不过藏得比较深,而且官网上的文档没有提到

from modelscope.models.nlp.task_models import ModelForFillMask
from modelscope.models.nlp.ponet.tokenization import PoNetTokenizer

 这时候只要像用huggingface那样用ModelScope就可以了

plm = ModelForFillMask.from_pretrained("damo/nlp_ponet_fill-mask_english-base", trust_remote_code=True)
plm.cuda()
tokenizer = PoNetTokenizer.from_pretrained("damo/nlp_ponet_fill-mask_english-base", trust_remote_code=True)

但是这么写又有问题了

我寻思我不是用ModelScope加载吗?为啥下载tokenizer的时候链接变成了从huggingface下载???

 于是我就把tokenizer名称换成在huggingface上的名称

plm = ModelForFillMask.from_pretrained("damo/nlp_ponet_fill-mask_english-base", trust_remote_code=True)
plm.cuda()
tokenizer = PoNetTokenizer.from_pretrained("chtan/ponet-base-uncased", trust_remote_code=True)

这样就成功了

完整代码:

from modelscope.models.nlp.task_models import ModelForFillMask
from modelscope.models.nlp.ponet.tokenization import PoNetTokenizer
import torch

plm = ModelForFillMask.from_pretrained("damo/nlp_ponet_fill-mask_english-base", trust_remote_code=True)
plm.cuda()
tokenizer = PoNetTokenizer.from_pretrained("chtan/ponet-base-uncased", trust_remote_code=True)

inputs = tokenizer("I want a cup of " + tokenizer.mask_token + ", please.", return_tensors="pt").to("cuda")
mask_token_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0].item()

logits = plm(**inputs).logits
mask_logits = logits[0][mask_token_index]
topk_ids = torch.topk(mask_logits, k=5, sorted=True, largest=True)
topk_tokens = tokenizer.convert_ids_to_tokens(topk_ids.indices)
print(topk_tokens)

运行结果:

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值