PyTorch环境下对BERT进行Fine-tuning

本文详述了在PyTorch环境中对预训练的BERT模型进行微调的过程,以适应Quora Question Pairs数据集的问题对一致性判断任务。内容包括数据准备、模型加载、训练和预测,提供了完整的代码实现。
摘要由CSDN通过智能技术生成

本文根据Chris McCormick的BERT微调教程进行优化并使其适应于数据集Quora Question Pairs里的判断问题对是否一致的任务。(文字部分大部分为原文的翻译)

原文博客地址:https://mccormickml.com/2019/07/22/BERT-fine-tuning/

原文colab地址:https://colab.research.google.com/drive/1pTuQhug6Dhl9XalKB0zUGf4FIdYFlpcX

本文项目地址:https://github.com/yxf975/pretraining_models_learning

前言

本文对删除了很多原英文博文中一些介绍性的内容,着重于如何实现基础的BERT微调方法。本解决方法不同于Chris McCormick的有以下几点:

  • 使用的数据集为Quora问题对数据集
  • 添加了多gpu运行的选择
  • 将部分代码封装进了函数中,方便使用
  • 添加了预测部分

具体对于BERT等预训练模型的原理的理解,我会单独创建一个话题,让我们直接开始吧!

准备工作

检查GPU

为了让 torch 使用 GPU,我们需要识别并指定 GPU 作为设备。稍后,在我们的训练循环中,我们将把数据加载到设备上。

import torch

# If there's a GPU available...
if torch.cuda.is_available():    

    # Tell PyTorch to use the GPU.    
    device = torch.device("cuda")
    n_gpu = torch.cuda.device_count()

    print('There are %d GPU(s) available.' % n_gpu)

    print('We will use the GPU:', [torch.cuda.get_device_name(i) for i in range(n_gpu)])

# If not...
else:
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")
安装Transformer库

目前,Hugging Face的Transformer库似乎是最被广泛接受的、最强大的与BERT合作的pytorch接口。除了支持各种不同的预先训练好的变换模型外,该库还包含了这些模型的预构建修改,适合你的特定任务。例如,在本教程中,我们将使用BertForSequenceClassification

该库还包括用于标记分类、问题回答、下句预测等的特定任务类。使用这些预建的类可以简化为您的目的修改BERT的过程。

!pip install transformers

加载Quora Question Pairs数据

数据集在kaggle官网上,注册登录即可下载,下载地址:https://www.kaggle.com/c/quora-question-pairs 。另外本人在google drive上也共享了数据集,下载地址:https://drive.google.com/drive/folders/1kFkte0Kt2xLe6Ykl4O4_TrL2iCzorOYk

Quora Question Pairs数据集介绍

这个数据集针对于Quora平台,很多人在Quora上会提出类似措辞的问题。具有相同意图的多个问题可能会导致搜寻者花费更多时间来寻找问题的最佳答案,并使作者感到他们需要回答同一问题的多个版本。

该任务需要对问题对是否重复进行分类,从而解决自然语言处理问题。这样做将使查找问题的高质量答案变得更加容易,从而为Quora的作家,搜寻者和读者带来了更好的体验。

pandas加载数据
import pandas as pd
import numpy as np

# Load the dataset into a pandas dataframe.
train_data = pd.read_csv("./train.csv", index_col="id",nrows=10000)
train_data.head(6)

这里我显示6行,因为到第六行才有个正样本。

id qid1 qid2 question1 question2 is_duplicate
0 1 2 What is the step by step guide to invest in share market in india? What is the step by step guide to invest in share market? 0
1 3 4 What is the story of Kohinoor (Koh-i-Noor) Diamond? What would happen if the Indian government stole the Kohinoor (Koh-i-Noor) diamond back? 0
2 5 6 How can I increase the speed of my internet connection while using a VPN? How can Internet speed be increased by hacking through DNS? 0
3 7 8 Why am I mentally very lonely? How can I solve it? Find the remainder when [math]23^{24}[/math] is divided by 24,23? 0
4 9 10 Which one dissolve in water quikly sugar, salt, methane and carbon di oxide? Which fish would survive in salt water? 0
5 11 12 Astrology: I am a Capricorn Sun Cap moon and cap rising…what does that say about me? I’m a triple Capricorn (Sun, Moon and ascendant in Capricorn) What does this say about me? 1

我们实际关心的三个属性是"question1",“question1"和它们的标签"is_duplicate”,这个标签被称为"是否重复"(0=不重复,1=重复)。

训练集验证集拆分

把我们的训练集分成 80% 用于训练,20% 用于验证。

from sklearn.model_selection import train_test_split

# train_validation data split
X_train, X_val, y_train, y_val = train_test_split(train_data[["question1", "question2"]], train_data["is_duplicate"], test_size=0.2, random_state=405633)

Tokenization & Input 格式化

BERT Tokenizer
from transformers import BertTokenizer

# load bert tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
查看数据中句子的最长长度
#calculate the maximum sentence length
max_len  = 0
for _, row in train_data.iterrows():
    max_len = max(max_len, len(tokenizer(row['question1'],row['question2'])["input_ids"]))

print("max token length of the input:", max_len)
    
# set the maximum token length
max_length = pow(2,int(np.log2(max_len)+1)
  • 5
    点赞
  • 32
    收藏
    觉得还不错? 一键收藏
  • 15
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 15
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值