预训练模型mlm阅读理解任务

5 篇文章 0 订阅
4 篇文章 0 订阅

bert、roberta、ernie在中文mlm任务上效果查看

# -*- coding: utf-8 -*-
import torch
from transformers import BertTokenizer, BertForMaskedLM


def get_mlm_model(list_):
    ret = []
    for path in list_:
        tokenizer = BertTokenizer.from_pretrained(path)
        model = BertForMaskedLM.from_pretrained(path)
        ret.append((path, tokenizer, model))
    return ret


def gen_text(input_tx, tokenizer, model):
    tokenized_text = tokenizer.tokenize(input_tx)
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)

    tokens_tensor = torch.tensor([indexed_tokens])
    segments_tensors = torch.tensor([[0] * len(tokenized_text)])

    with torch.no_grad():
        outputs = model(tokens_tensor, token_type_ids=segments_tensors)
        predictions = outputs[0]

    predicted_index = [torch.argmax(predictions[0, i]).item() for i in range(0, (len(tokenized_text) - 1))]
    predicted_token = [tokenizer.convert_ids_to_tokens([predicted_index[x]])[0] for x in
                       range(1, (len(tokenized_text) - 1))]
    predicted_token = ''.join(predicted_token)
    print('raw token is:', input_tx)
    print('Predicted token is:', predicted_token)
    return predicted_token


if __name__ == '__main__':
    list_ = get_mlm_model([
        'bert-base-chinese',
        'nghuyong/ernie-1.0',
        'hfl/chinese-roberta-wwm-ext',
        # 'voidful/albert_chinese_tiny',  # albert有点问题,有些层没参数,使用的是初始化参数
    ])
    inputs = [
        "[CLS]清华大学[MASK][MASK]在哪里[SEP]",
        "[CLS] [MASK] [MASK] [MASK] 是中国神魔小说的经典之作,与《三国演义》《水浒传》《红楼梦》并称为中国古典四大名著。[SEP]",
        "[CLS][MASK][MASK][MASK]是中国神魔小说的经典之作,与《三国演义》《水浒传》《红楼梦》并称为中国古典四大名著。[SEP]",
        "[CLS]今天的股票会[MASK]吗[SEP]",
        "[CLS]今天的股票会[MASK][MASK][SEP]",
    ]
    for input_ in inputs:
        for name, tokenizer, model in list_:
            print(name)
            gen_text(input_, tokenizer, model)
            print()

结果
bert-base-chinese
raw token is: [CLS]清华大学[MASK][MASK]在哪里[SEP]
Predicted token is: 。华大学校址在哪里

nghuyong/ernie-1.0
raw token is: [CLS]清华大学[MASK][MASK]在哪里[SEP]
Predicted token is: 清华大学大华在哪里

hfl/chinese-roberta-wwm-ext
raw token is: [CLS]清华大学[MASK][MASK]在哪里[SEP]
Predicted token is: 清华大学究底在哪里

bert-base-chinese
raw token is: [CLS] [MASK] [MASK] [MASK] 是中国神魔小说的经典之作,与《三国演义》《水浒传》《红楼梦》并称为中国古典四大名著。[SEP]
Predicted token is: 《庸》是中国神魔小说的经典之作,与《三国演义》《水浒传》《红楼梦》并称为中国古典四大名著。

nghuyong/ernie-1.0
raw token is: [CLS] [MASK] [MASK] [MASK] 是中国神魔小说的经典之作,与《三国演义》《水浒传》《红楼梦》并称为中国古典四大名著。[SEP]
Predicted token is: 西游记是中国神魔小说的经典之作,与《三国演义》《水浒传》《红楼梦》并称为中国古典四大名著。

hfl/chinese-roberta-wwm-ext
raw token is: [CLS] [MASK] [MASK] [MASK] 是中国神魔小说的经典之作,与《三国演义》《水浒传》《红楼梦》并称为中国古典四大名著。[SEP]
Predicted token is: 西游梦是中国神魔小说的经典之作,与《三国演义》《水浒传》《红楼梦》并称为中国古典四大名著。

bert-base-chinese
raw token is: [CLS][MASK][MASK][MASK]是中国神魔小说的经典之作,与《三国演义》《水浒传》《红楼梦》并称为中国古典四大名著。[SEP]
Predicted token is: 《庸》是中国神魔小说的经典之作,与《三国演义》《水浒传》《红楼梦》并称为中国古典四大名著。

nghuyong/ernie-1.0
raw token is: [CLS][MASK][MASK][MASK]是中国神魔小说的经典之作,与《三国演义》《水浒传》《红楼梦》并称为中国古典四大名著。[SEP]
Predicted token is: 西游记是中国神魔小说的经典之作,与《三国演义》《水浒传》《红楼梦》并称为中国古典四大名著。

hfl/chinese-roberta-wwm-ext
raw token is: [CLS][MASK][MASK][MASK]是中国神魔小说的经典之作,与《三国演义》《水浒传》《红楼梦》并称为中国古典四大名著。[SEP]
Predicted token is: 西游梦是中国神魔小说的经典之作,与《三国演义》《水浒传》《红楼梦》并称为中国古典四大名著。

bert-base-chinese
raw token is: [CLS]今天的股票会[MASK]吗[SEP]
Predicted token is: 。天的股票会跌?

nghuyong/ernie-1.0
raw token is: [CLS]今天的股票会[MASK]吗[SEP]
Predicted token is: 今天的股票会涨吗

hfl/chinese-roberta-wwm-ext
raw token is: [CLS]今天的股票会[MASK]吗[SEP]
Predicted token is: 今天的股票会涨吗

bert-base-chinese
raw token is: [CLS]今天的股票会[MASK][MASK][SEP]
Predicted token is: 。天的。票会吗?

nghuyong/ernie-1.0
raw token is: [CLS]今天的股票会[MASK][MASK][SEP]
Predicted token is: 今天的股票会怎样

hfl/chinese-roberta-wwm-ext
raw token is: [CLS]今天的股票会[MASK][MASK][SEP]
Predicted token is: 今天的股票会涨吗


 

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值