使用Hugginface的Transformers库快速微调BERT等预训练模型,使其适应下游任务,本文以Quora问题对任务为例,对两问题表意是否一致进行预测
介绍
之前写了个微调BERT的入门教程,步骤比较多,我后来把它都封装成函数了,但用起来还是比较麻烦,而且有时候Transformer库里一些函数方法会改动,这就还要对代码进行调整。最近看Transformers的文档,发现它自己封装了一个Trainer
函数,很好用,分享一下
本文完整的jupyter notebook地址为:https://github.com/yxf975/PTMs_learning/blob/main/%E9%A2%84%E8%AE%AD%E7%BB%83%E6%A8%A1%E5%9E%8B%E5%BE%AE%E8%B0%83.ipynb
这次依然还是在Quora问题对数据集上进行,任务这里再重复一遍,就是给定的两个问题,预测其表达的意思是否是相似的。这个任务是建立在这样的应用场景下的,在Quora这个软件里,很多人可能回问相似或同样的问题,但往往每个问题都需要来对其重复解答,很不效率。因此如果我们有个能判定问题是否相似或一致的算法,那么我们就能给出之前相似问题的答案,提高效率。
安装相应的库和导入包
废话不多说,直接上代码。首先先安装相应的库,一个是熟知的Hugginface的transformers,另一个是Hugginface的nlp数据集库,意味着我们之后直接用这个库载入数据。
!pip install transformers
!pip install nlp
之后导入需要用的包
import torch
import pandas as pd
from sklearn.model_selection import train_test_split
from transformers import BertForSequenceClassification, Trainer, TrainingArguments,BertTokenizerFast
from nlp import load_dataset
from nlp import Dataset
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, log_loss
from scipy.special import softmax
载入数据集
这里有两种方情况:一个是载入nlp库里的数据集,还有一个就是使用本地的数据集,接下来对这两种情况都来尝试一下
载入nlp库里的数据集
载入数据集,使用nlp.load_dataset方法,具体完整的文档地址为https://huggingface.co/docs/datasets/v0.3.0/index.html。
# 从hugginface的数据集库中下载quora数据集
# 也可以全部读取完用dataset.train_test_split(test_size = 0.1)这样来读取
train_dataset = load_dataset('quora', split='train[:20%]') #读取前20%数据做训练集
validation_dataset = load_dataset('quora', split='train[30%:35%]') #读取30到35%之间数据用作验证集
test_dataset = load_dataset('quora', split='train[-10%:]')#取后10%用作测试集
接着导入预训练模型的分词器和模型,这里我用的还是BERT_BASE模型
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2,output_hidden_states = False)
预处理函数,这里是为了提取出每条样本中的问题文本并进行分词操作以及label,具体原数据集的形式可以自行查看,这里就不多说了。另外这里超参序列长度max_length
是在这部分设置的,为了跑的快一点我设置的很小,最大为512。
def preprocess_function(examples):
# Tokenize the texts
result = tokenizer([examples['questions'][i]['text'][0] for i in range(len(examples['questions']))],
[examples['questions'][i]['text'][1] for i in range(len(examples['questions']))],
padding=True, truncation=True, max_length=32)