使用huggingface全家桶(transformers, datasets)实现一条龙BERT训练(trainer)和预测(pipeline)

使用huggingface全家桶(transformers, datasets)实现一条龙BERT训练(trainer)和预测(pipeline)

huggingface的transformers在我写下本文时已有39.5k star,可能是目前最流行的深度学习库了,而这家机构又提供了datasets这个库,帮助快速获取和处理数据。这一套全家桶使得整个使用BERT类模型机器学习流程变得前所未有的简单。

不过,目前我在网上没有发现比较简单的关于整个一套全家桶的使用教程。所以写下此文,希望帮助更多人快速上手。

这里,我们以AGNews新闻分类任务为例,演示整套流程的实现。

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"  # 在此我指定使用2号GPU,可根据需要调整
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from transformers import Trainer, TrainingArguments
from transformers import pipeline
from datasets import load_dataset

使用datasets读取数据集

下面的代码读取原始数据集的train部分的前40000条作为我们的训练集,40000-50000条作为开发集(只使用这个子集已经可以训出不错的模型,并且可以让训练时间更短),原始的测试集作为我们的测试集。

train_dataset = load_dataset("ag_news", split="train[:40000]")
dev_dataset = load_dataset("ag_news", split="train[40000:50000]")
test_dataset = load_dataset("ag_news", split="test")
print(train_dataset)
print(dev_dataset)
print(test_dataset)
Dataset({
    features: ['text', 'label'],
    num_rows: 40000
})
Dataset({
    features: ['text', 'label'],
    num_rows: 10000
})
Dataset({
    features: ['text', 'label'],
    num_rows: 7600
})

原始数据集包含text和label两个字段

train_dataset.features
{'text': Value(dtype='string', id=None),
 'label': ClassLabel(num_classes=4, names=['World', 'Sports', 'Business', 'Sci/Tech'], names_file=None, id=None)}

由于bert模型期望得到的标签的字段为labels而原始数据集中的名字是label,所以做一下调整。

下面的代码把label字段复制到labels

train_dataset = train_dataset.map(lambda examples: {
   'labels': examples['label']}, batched=True)
train_dataset[0]
{'label': 2,
 'labels': 2,
 'text': "Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green aga
  • 20
    点赞
  • 66
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值