Transformer预训练任务总结
文章目录
1、介绍
这篇文章总结了基于Transformer库构建自己的任务的方案;更多信息可参考transformer文档。
总体来说,transformer接口提供了两种方法(pipline方法
和use model方法
)来便捷的执行各项任务,包括序列分类,QA,NER等;所有的这些方法,都是基于预训练模型执行的,你可以基于预训练模型来构建自己的模型,并在自己的数据集上进行微调;
下面的例子将分别使用pipline和use model两种机制来介绍使用方法;
- piplines: 非常易于使用的抽象接口,你只需要两行代码就能构建自己的分类器;
- use model: 抽象更少,灵活性和功能更强,建议使用这种方法来构建自己的模型。
下面通过不同任务的例子来总结使用方法,所有的例子都需要预先导入预训练的模型,这些模型你可以在hugging face下载,这里也有一些官方例子可供学习;
hugging Face: Models - Hugging Face
examples: examples
2、例子总结
本篇总结的所有的例子总结都是torch版本,tf版本参见源文档;
2.1、序列分类
🐈 pipline方法:
这里使用一个pipline方法,构建一个情感分析的2分类器,用于判别描述文本类别;
from transformers import pipeline
classifier = pipeline(task="sentiment-analysis", model="pre_trains/distilbert-base-uncased-finetuned-sst-2-english")
result = classifier("I hate you")[0]
print(f"label: {
result['label']}, with score: {
round(result['score'], 4)}")
result = classifier("I love you")[0]
print(f"label: {
result['label']}, with score: {
round(result['score'], 4)}")
🐕 use model方法
这里使用use model机制来做序列分类,判断两个句子是否为转述;流程包括以下几个步骤
1)从预训练模型构建一个tokenizer以及model
2)使用tokenizer编码文本为输入编码
3)通过模型计算,得到输出
4)softmax计算类别;
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
pre_train_path = "./pre_trains/distilbert-base-uncased-finetuned-sst-2-english"
tokenizer = AutoTokenizer.from_pretrained(pre_train_path)
model = AutoModelForSequenceClassification.from_pretrained(pre_train_path)
classes_name = ["not paraphrase", "is paraphrase"]
seq0 = "The company HuggingFace is based in New York City"
seq1 = "Apples are especially bad for your health"
seq2 = "HuggingFace's headquarters are situated in Manhattan"
inputs = [[seq0, seq2],
[seq0, seq1]]
para = tokenizer(inputs, return_tensors="pt", padding=True)
cls_lgt = model(**para).logits
cls_class = torch.softmax(cls_lgt, dim=1).tolist()
# output
for i, j in zip(cls_class, inputs):
print(j)
print('结果'.center(20, '='))
for value, c_name in zip(i, classes_name):
print(f'类别:{
c_name} -- 置信度:{
value}')
输出
['The company HuggingFace is based in New York City', "HuggingFace's headquarters are situated in Manhattan"]
=========结果=========
类别:not paraphrase -- 置信度:0.02561321295797825
类别:is paraphrase -- 置信度:0.9743868112564087
['The company HuggingFace is based in New York City', 'Apples are especially bad for your health']
=========结果=========
类别:not paraphrase -- 置信度:0.9997486472129822
类别:is paraphrase -- 置信度:0.00025137700140476227
2.2、EQA
EQA:Extractive Question Answering, 抽取式问答,任务在给出的问题文本当中抽取指定字段构建答案;SQuAD数据集是该任务的典型数据集,如果你想在该数据集上微调,可以利用run_qa.py脚本去训练;下面