【总结】Transformer预训练任务总结

transformer预训练模型使用记录
摘要由CSDN通过智能技术生成

Transformer预训练任务总结

1、介绍

这篇文章总结了基于Transformer库构建自己的任务的方案;更多信息可参考transformer文档

总体来说,transformer接口提供了两种方法(pipline方法use model方法)来便捷的执行各项任务,包括序列分类,QA,NER等;所有的这些方法,都是基于预训练模型执行的,你可以基于预训练模型来构建自己的模型,并在自己的数据集上进行微调;

下面的例子将分别使用pipline和use model两种机制来介绍使用方法;

  • piplines: 非常易于使用的抽象接口,你只需要两行代码就能构建自己的分类器;
  • use model: 抽象更少,灵活性和功能更强,建议使用这种方法来构建自己的模型。

下面通过不同任务的例子来总结使用方法,所有的例子都需要预先导入预训练的模型,这些模型你可以在hugging face下载,这里也有一些官方例子可供学习;

hugging Face: Models - Hugging Face

examples: examples

2、例子总结

本篇总结的所有的例子总结都是torch版本,tf版本参见源文档;

2.1、序列分类

🐈 pipline方法:

这里使用一个pipline方法,构建一个情感分析的2分类器,用于判别描述文本类别;

from transformers import pipeline

classifier = pipeline(task="sentiment-analysis", model="pre_trains/distilbert-base-uncased-finetuned-sst-2-english")

result = classifier("I hate you")[0]
print(f"label: {
     result['label']}, with score: {
     round(result['score'], 4)}")

result = classifier("I love you")[0]
print(f"label: {
     result['label']}, with score: {
     round(result['score'], 4)}")

🐕 use model方法

这里使用use model机制来做序列分类,判断两个句子是否为转述;流程包括以下几个步骤

1)从预训练模型构建一个tokenizer以及model

2)使用tokenizer编码文本为输入编码

3)通过模型计算,得到输出

4)softmax计算类别;

from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

pre_train_path = "./pre_trains/distilbert-base-uncased-finetuned-sst-2-english"

tokenizer = AutoTokenizer.from_pretrained(pre_train_path)
model = AutoModelForSequenceClassification.from_pretrained(pre_train_path)

classes_name = ["not paraphrase", "is paraphrase"]
seq0 = "The company HuggingFace is based in New York City"
seq1 = "Apples are especially bad for your health"
seq2 = "HuggingFace's headquarters are situated in Manhattan"

inputs = [[seq0, seq2],
          [seq0, seq1]]
para = tokenizer(inputs, return_tensors="pt", padding=True)

cls_lgt = model(**para).logits


cls_class = torch.softmax(cls_lgt, dim=1).tolist()

# output
for i, j in zip(cls_class, inputs):
    print(j)
    print('结果'.center(20, '='))
    for value, c_name in zip(i, classes_name):
        print(f'类别:{
     c_name} -- 置信度:{
     value}')

输出

['The company HuggingFace is based in New York City', "HuggingFace's headquarters are situated in Manhattan"]
=========结果=========
类别:not paraphrase -- 置信度:0.02561321295797825
类别:is paraphrase -- 置信度:0.9743868112564087
['The company HuggingFace is based in New York City', 'Apples are especially bad for your health']
=========结果=========
类别:not paraphrase -- 置信度:0.9997486472129822
类别:is paraphrase -- 置信度:0.00025137700140476227

2.2、EQA

EQA:Extractive Question Answering, 抽取式问答,任务在给出的问题文本当中抽取指定字段构建答案;SQuAD数据集是该任务的典型数据集,如果你想在该数据集上微调,可以利用run_qa.py脚本去训练;下面

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值