wobert用法记录

# 安装需要的包
# pip install git+https://github.com/JunnYu/WoBERT_pytorch.git
# pip install jieba


import jieba
str1 = "国家食药监管总局近日发布《食品召回管理办法》,明确:食用后已经或可能导致严重健康损害甚至死亡的,属一级召回,食品生产者应在知悉食品安全风险后24小时内启动召回,且自公告发布之日起10个工作日内完成召回。"
seg_list = jieba.cut(str1, cut_all=False)
print(*seg_list)



import torch
from transformers import BertForMaskedLM as WoBertForMaskedLM
from wobert import WoBertTokenizer

pretrained_model_or_path_list = [
    "junnyu/wobert_chinese_plus_base", "junnyu/wobert_chinese_base"
]
for path in pretrained_model_or_path_list:
    text = "国家食药监管总局近日发布《食品召回管理办法》,明确:食用后已经或可能导致严重健康损害甚至死亡的,属一级召回,食品生产者应在知悉食品安全风险后24小时内启动召回,且自公告发布之日起10个工作日内完成召回。"
    tokenizer = WoBertTokenizer.from_pretrained(path)
    model = WoBertForMaskedLM.from_pretrained(path)
    inputs = tokenizer.tokenize(text, return_tensors="pt")
    print(inputs)
    outputs_sentence =tokenizer.convert_tokens_to_ids(inputs)
    print(outputs_sentence)
    outputs =tokenizer.convert_ids_to_tokens(outputs_sentence, skip_special_tokens=True)
    print(outputs)



import jieba
str1 = "国家食药监管总局近日发布《食品召回管理办法》,明确:食用后已经或可能导致严重健康损害甚至死亡的,属一级召回,食品生产者应在知悉食品安全风险后24小时内启动召回,且自公告发布之日起10个工作日内完成召回。"
seg_list = list(jieba.cut(str1, cut_all=False))
print(seg_list)

with open(r'/kaggle/input/stopwords/hit_stopwords.txt','r') as fr:
    stopwords_list = fr.read().split()
token_list = []
for i in seg_list:
    if i in stopwords_list:
        continue
    token_list.append(i)
print(token_list)


import torch
from transformers import BertForMaskedLM as WoBertForMaskedLM
from wobert import WoBertTokenizer

pretrained_model_or_path_list = [
    "junnyu/wobert_chinese_plus_base", "junnyu/wobert_chinese_base"
]
for path in pretrained_model_or_path_list:
    text = "国家食药监管总局近日发布《食品召回管理办法》,明确:食用后已经或可能导致严重健康损害甚至死亡的,属一级召回,食品生产者应在知悉食品安全风险后24小时内启动召回,且自公告发布之日起10个工作日内完成召回。"
    tokenizer = WoBertTokenizer.from_pretrained(path)
    model = WoBertForMaskedLM.from_pretrained(path)
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs).logits[0]
    outputs_sentence = ""
    for i, id in enumerate(tokenizer.encode(text)):
        if id == tokenizer.mask_token_id:
            tokens = tokenizer.convert_ids_to_tokens(outputs[i].topk(k=5)[1])
            outputs_sentence += "[" + "||".join(tokens) + "]"
        else:
            outputs_sentence += "".join(
                tokenizer.convert_ids_to_tokens([id],
                                                skip_special_tokens=True))
    print(outputs_sentence)
# RoFormer    今天[天气||天||心情||阳光||空气]很好,我[想||要||打算||准备||喜欢]去公园玩。
# PLUS WoBERT 今天[天气||阳光||天||心情||空气]很好,我[想||要||打算||准备||就]去公园玩。
# WoBERT      今天[天气||阳光||天||心情||空气]很好,我[想||要||就||准备||也]去公园玩。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值