【Hugging Face】transformers 库中的 BertForSequenceClassification:文本分类任务

Hugging Face transformers 库中的 BertForSequenceClassification

BertForSequenceClassificationtransformers 库中的 BERT 变体,专门用于 文本分类任务(如情感分析、垃圾邮件检测、主题分类等)。它在 BertModel 的基础上添加了一个 分类头(全连接层),用于将 BERT 编码的文本表示映射到 类别标签


1. 为什么使用 BertForSequenceClassification

BERT 作为 通用的文本编码器,可以用于 分类任务,但 BertModel 仅输出 文本的隐藏状态,如果要做分类,还需要额外的 全连接层 进行分类。

BertForSequenceClassification 直接在 BertModel 上添加了一个 线性分类层,使得 BERT 可以直接进行分类任务,适用于:

  • 情感分析(如 IMDB 数据集)
  • 垃圾邮件检测
  • 新闻分类
  • 法律/医学文本分类
  • 其他 NLP 分类任务

2. BertForSequenceClassification 的基本用法

2.1. 加载预训练的 BertForSequenceClassification

from transformers import BertTokenizer, BertForSequenceClassification

# 加载 tokenizer 和模型(2分类)
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

2.2. 进行文本分类

假设我们有一个句子:

text = "Hugging Face is an amazing company!"

我们需要:

  1. 使用 tokenizer 进行分词
  2. 转换为张量
  3. 输入 BertForSequenceClassification 进行分类
  4. 获取预测类别
import torch

# 编码文本
inputs = tokenizer(text, return_tensors="pt")

# 进行预测
with torch.no_grad():
    outputs = model(**inputs)

# 获取 logits(原始分类分数)
logits = outputs.logits
predicted_class = torch.argmax(logits, dim=1).item()

print(f"Predicted class: {predicted_class}")

如果 num_labels=2,输出:

Predicted class: 1  # 可能表示“正面情感”

3. BertForSequenceClassification 的关键参数

参数作用
BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)指定分类类别数
outputs.logits获取模型的原始分类分数
torch.argmax(logits, dim=1)获取最终预测类别
tokenizer(text, return_tensors="pt")预处理文本,转换为张量

4. 训练 BertForSequenceClassification

如果你想 微调 BERT 进行文本分类,可以使用 Hugging Face Trainer

4.1. 加载数据集

我们使用 Hugging Face datasets 库加载 IMDB 影评数据:

from datasets import load_dataset

dataset = load_dataset("imdb")

这个数据集包含:

  • train(训练集)
  • test(测试集)
  • 每条数据有 textlabellabel=0 代表负面,label=1 代表正面。

4.2. 预处理数据

def preprocess(example):
    return tokenizer(example["text"], padding="max_length", truncation=True)

encoded_dataset = dataset.map(preprocess, batched=True)

4.3. 训练模型

使用 Trainer 进行微调:

from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    save_steps=500,
    save_total_limit=2,
    num_train_epochs=3,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["test"],
)
trainer.train()

5. BertForSequenceClassification 的模型架构

BertForSequenceClassification 主要由 BERT 编码器 + 线性分类层 组成:

[BERT Encoder]
      ↓
[CLS Token Representation]
      ↓
[Linear Classifier] →  分类标签
  • BERT 提供文本的上下文表示
  • CLS token(第一个 token)作为文本的整体表示
  • Linear Classifier 进行分类

如果你想自定义分类头,可以使用:

from transformers import BertModel
import torch.nn as nn

class CustomBERTClassifier(nn.Module):
    def __init__(self, num_labels=2):
        super().__init__()
        self.bert = BertModel.from_pretrained("bert-base-uncased")
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        logits = self.classifier(outputs.pooler_output)  # 使用 CLS token
        return logits

6. BertForSequenceClassification 适用于哪些任务?

任务适用情况
情感分析
垃圾邮件检测
新闻分类
文本匹配任务(如 NLI)适用,但推荐 BertForNextSentencePrediction
问答(Question Answering)不适用,推荐 BertForQuestionAnswering

7. BertForSequenceClassification vs BertModel

模型作用
BertModel仅包含 Transformer 编码器
BertForSequenceClassificationBertModel + 线性分类层
BertForTokenClassificationBertModel + 序列标注(NER)层
BertForQuestionAnsweringBertModel + 问答层

如果你的任务是 分类任务,推荐 BertForSequenceClassification


8. BertForSequenceClassification vs RoBERTa/GPT-2

模型任务类型适用情况
BertForSequenceClassification文本分类适用于所有文本分类任务
RoBERTaForSequenceClassification文本分类与 BERT 类似,但去掉了 NSP 任务
GPT-2语言生成适用于文本生成任务

如果你的任务是 文本分类,推荐 BertForSequenceClassification,如果是 文本生成,推荐 GPT-2


9. 总结

  1. BertForSequenceClassification 是 BERT 的文本分类版本,适用于 情感分析、垃圾邮件检测、新闻分类 等任务。
  2. BertModel 之上增加了 线性分类层,可以直接用于分类任务。
  3. 使用 Trainer 进行微调,可以轻松在 Hugging Face datasets 数据集上训练自定义分类模型。
  4. 适用于二分类和多分类任务,可以通过 num_labels=3 进行三分类等扩展。
  5. 可以替换默认的分类头,通过自定义 nn.Linear 进行调整。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

彬彬侠

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值