本文根据Chris McCormick的BERT微调教程进行优化并使其适应于数据集Quora Question Pairs里的判断问题对是否一致的任务。(文字部分大部分为原文的翻译)
原文博客地址:https://mccormickml.com/2019/07/22/BERT-fine-tuning/
原文colab地址:https://colab.research.google.com/drive/1pTuQhug6Dhl9XalKB0zUGf4FIdYFlpcX
本文项目地址:https://github.com/yxf975/pretraining_models_learning
前言
本文对删除了很多原英文博文中一些介绍性的内容,着重于如何实现基础的BERT微调方法。本解决方法不同于Chris McCormick的有以下几点:
- 使用的数据集为Quora问题对数据集
- 添加了多gpu运行的选择
- 将部分代码封装进了函数中,方便使用
- 添加了预测部分
具体对于BERT等预训练模型的原理的理解,我会单独创建一个话题,让我们直接开始吧!
准备工作
检查GPU
为了让 torch 使用 GPU,我们需要识别并指定 GPU 作为设备。稍后,在我们的训练循环中,我们将把数据加载到设备上。
import torch
# If there's a GPU available...
if torch.cuda.is_available():
# Tell PyTorch to use the GPU.
device = torch.device("cuda")
n_gpu = torch.cuda.device_count()
print('There are %d GPU(s) available.' % n_gpu)
print('We will use the GPU:', [torch.cuda.get_device_name(i) for i in range(n_gpu)])
# If not...
else:
print('No GPU available, using the CPU instead.')
device = torch.device("cpu")
安装Transformer库
目前,Hugging Face的Transformer库似乎是最被广泛接受的、最强大的与BERT合作的pytorch接口。除了支持各种不同的预先训练好的变换模型外,该库还包含了这些模型的预构建修改,适合你的特定任务。例如,在本教程中,我们将使用BertForSequenceClassification
。
该库还包括用于标记分类、问题回答、下句预测等的特定任务类。使用这些预建的类可以简化为您的目的修改BERT的过程。
!pip install transformers
加载Quora Question Pairs数据
数据集在kaggle官网上,注册登录即可下载,下载地址:https://www.kaggle.com/c/quora-question-pairs 。另外本人在google drive上也共享了数据集,下载地址:https://drive.google.com/drive/folders/1kFkte0Kt2xLe6Ykl4O4_TrL2iCzorOYk
Quora Question Pairs数据集介绍
这个数据集针对于Quora平台,很多人在Quora上会提出类似措辞的问题。具有相同意图的多个问题可能会导致搜寻者花费更多时间来寻找问题的最佳答案,并使作者感到他们需要回答同一问题的多个版本。
该任务需要对问题对是否重复进行分类,从而解决自然语言处理问题。这样做将使查找问题的高质量答案变得更加容易,从而为Quora的作家,搜寻者和读者带来了更好的体验。
pandas加载数据
import pandas as pd
import numpy as np
# Load the dataset into a pandas dataframe.
train_data = pd.read_csv("./train.csv", index_col="id",nrows=10000)
train_data.head(6)
这里我显示6行,因为到第六行才有个正样本。
id | qid1 | qid2 | question1 | question2 | is_duplicate |
---|---|---|---|---|---|
0 | 1 | 2 | What is the step by step guide to invest in share market in india? | What is the step by step guide to invest in share market? | 0 |
1 | 3 | 4 | What is the story of Kohinoor (Koh-i-Noor) Diamond? | What would happen if the Indian government stole the Kohinoor (Koh-i-Noor) diamond back? | 0 |
2 | 5 | 6 | How can I increase the speed of my internet connection while using a VPN? | How can Internet speed be increased by hacking through DNS? | 0 |
3 | 7 | 8 | Why am I mentally very lonely? How can I solve it? | Find the remainder when [math]23^{24}[/math] is divided by 24,23? | 0 |
4 | 9 | 10 | Which one dissolve in water quikly sugar, salt, methane and carbon di oxide? | Which fish would survive in salt water? | 0 |
5 | 11 | 12 | Astrology: I am a Capricorn Sun Cap moon and cap rising…what does that say about me? | I’m a triple Capricorn (Sun, Moon and ascendant in Capricorn) What does this say about me? | 1 |
我们实际关心的三个属性是"question1",“question1"和它们的标签"is_duplicate”,这个标签被称为"是否重复"(0=不重复,1=重复)。
训练集验证集拆分
把我们的训练集分成 80% 用于训练,20% 用于验证。
from sklearn.model_selection import train_test_split
# train_validation data split
X_train, X_val, y_train, y_val = train_test_split(train_data[["question1", "question2"]], train_data["is_duplicate"], test_size=0.2, random_state=405633)
Tokenization & Input 格式化
BERT Tokenizer
from transformers import BertTokenizer
# load bert tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
查看数据中句子的最长长度
#calculate the maximum sentence length
max_len = 0
for _, row in train_data.iterrows():
max_len = max(max_len, len(tokenizer(row['question1'],row['question2'])["input_ids"]))
print("max token length of the input:", max_len)
# set the maximum token length
max_length = pow(2,int(np.log2(max_len)+1)