BERT 多分类实战:从训练到评估的完整指南

BERT 多分类实战

项目介绍

在自然语言处理(NLP)领域,BERT 已成为强大的预训练模型,适用于各类文本分类任务。在这篇技术文章中,我们将展示如何使用 Hugging Face 的 transformers.Trainer,构建并训练一个基于 BERT 的多分类模型,并对其进行评估。

准备工作

在开始动手前,我们需要准备以下内容:

  1. 必要的库transformers(Hugging Face提供的BERT库)、torch(PyTorch框架)、sklearn(常用的机器学习工具库)

  2. 数据集:为了演示,我们使用一个包含新闻文本的多分类数据集。每条新闻需要被分类为['World', 'Sports', 'Business', 'Sci/Tech']等类别。

dataset = load_dataset("fancyzhx/ag_news")

数据预处理

该数据集为英文数据集,故使用bert-base-uncased 模型。

model_name = "bert-base-uncased"

tokenizer = AutoTokenizer.from_pretrained(model_name)

def tokenize_func(tokenizer):

    def _func(item):
        return tokenizer(
            item["text"],
            max_length=512,
            truncation=True,
        )

    return _func

new_dataset = dataset.map(tokenize_func(tokenizer=tokenizer))

训练

构建基于BERT的多分类模型。Hugging Face的transformers库提供了预训练的BERT模型,并且我们可以很方便地在这个基础上微调。

bert = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    trust_remote_code=True,
    num_labels=4,
)

在这里,我们使用了BertForSequenceClassification,它是一个预训练的BERT模型,同时包含一个序列分类头,用于文本分类任务。

args = TrainingArguments(
    output_dir=output_dir,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=3,
    learning_rate=2e-5,
    num_train_epochs=epoch,
    weight_decay=0.01,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=16,
    # logging_steps=16,
    save_safetensors=True,
    overwrite_output_dir=True,
    load_best_model_at_end=True,
)

trainer = Trainer(
    model=self.model,
    args=args,
    train_dataset=self.train_dataset,
    eval_dataset=self.eval_dataset,
    data_collator=self.data_collator,
    compute_metrics=self.compute_metrics,
    tokenizer=tokenizer,
)

使用了Trainer来简化训练流程,定义了训练的相关参数,比如训练轮次、batch size等。

评估

在测试集上的评估如下所示:

from sklearn.metrics import classification_report, confusion_matrix

使用 classification_report 评估模型的预测情况。

在这里插入图片描述

[注]: support 代表数据样本的数量。

使用 confusion_matrix 计算混淆矩阵。最后在测试集上,预测的混淆矩阵如下所示,利用 confusion_matrix计算混淆矩阵。

在这里插入图片描述

在多分类问题中,模型评估是非常重要的环节,混淆矩阵(Confusion Matrix)作为一种直观的评估工具,被广泛应用于分类问题。

混淆矩阵是一个 n x n 的矩阵,其中 n 是分类任务中的类别数量。它展示了模型在每个分类上的真实标签预测标签的分布情况。矩阵的每一行代表实际的类,列代表模型预测的类。

对角线元素表示模型正确分类的数量,它们表示真实类别和预测类别完全一致的样本数。

因此它直观展示出了每个类别的分类情况,帮助我们识别模型在不同类别上的表现差异。

开源

项目代码:https://github.com/JieShenAI/wechat/blob/main/24/09/多分类实战/bert_多分类实战.ipynb

总结

在这篇文章中,展示了如何基于BERT进行多分类任务的完整流程。从数据预处理到模型训练再到评估。

接下来,你可以尝试用其他数据集来训练你的模型,甚至调整BERT的预训练模型或使用不同的超参数进行微调。无论是文本分类、情感分析,还是其他NLP任务。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

jieshenai

为了遇见更好的文章

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值