Hugging Face transformers
库中的 BertForSequenceClassification
BertForSequenceClassification
是 transformers
库中的 BERT 变体,专门用于 文本分类任务(如情感分析、垃圾邮件检测、主题分类等)。它在 BertModel
的基础上添加了一个 分类头(全连接层),用于将 BERT 编码的文本表示映射到 类别标签。
1. 为什么使用 BertForSequenceClassification
?
BERT 作为 通用的文本编码器,可以用于 分类任务,但 BertModel
仅输出 文本的隐藏状态,如果要做分类,还需要额外的 全连接层 进行分类。
BertForSequenceClassification
直接在 BertModel
上添加了一个 线性分类层,使得 BERT 可以直接进行分类任务,适用于:
- 情感分析(如 IMDB 数据集)
- 垃圾邮件检测
- 新闻分类
- 法律/医学文本分类
- 其他 NLP 分类任务
2. BertForSequenceClassification
的基本用法
2.1. 加载预训练的 BertForSequenceClassification
from transformers import BertTokenizer, BertForSequenceClassification
# 加载 tokenizer 和模型(2分类)
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
2.2. 进行文本分类
假设我们有一个句子:
text = "Hugging Face is an amazing company!"
我们需要:
- 使用
tokenizer
进行分词 - 转换为张量
- 输入
BertForSequenceClassification
进行分类 - 获取预测类别
import torch
# 编码文本
inputs = tokenizer(text, return_tensors="pt")
# 进行预测
with torch.no_grad():
outputs = model(**inputs)
# 获取 logits(原始分类分数)
logits = outputs.logits
predicted_class = torch.argmax(logits, dim=1).item()
print(f"Predicted class: {predicted_class}")
如果 num_labels=2
,输出:
Predicted class: 1 # 可能表示“正面情感”
3. BertForSequenceClassification
的关键参数
参数 | 作用 |
---|---|
BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2) | 指定分类类别数 |
outputs.logits | 获取模型的原始分类分数 |
torch.argmax(logits, dim=1) | 获取最终预测类别 |
tokenizer(text, return_tensors="pt") | 预处理文本,转换为张量 |
4. 训练 BertForSequenceClassification
如果你想 微调 BERT 进行文本分类,可以使用 Hugging Face Trainer
。
4.1. 加载数据集
我们使用 Hugging Face datasets
库加载 IMDB 影评数据:
from datasets import load_dataset
dataset = load_dataset("imdb")
这个数据集包含:
train
(训练集)test
(测试集)- 每条数据有
text
和label
,label=0
代表负面,label=1
代表正面。
4.2. 预处理数据
def preprocess(example):
return tokenizer(example["text"], padding="max_length", truncation=True)
encoded_dataset = dataset.map(preprocess, batched=True)
4.3. 训练模型
使用 Trainer
进行微调:
from transformers import Trainer, TrainingArguments
training_args = TrainingArguments(
output_dir="./results",
evaluation_strategy="epoch",
per_device_train_batch_size=8,
per_device_eval_batch_size=8,
save_steps=500,
save_total_limit=2,
num_train_epochs=3,
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=encoded_dataset["train"],
eval_dataset=encoded_dataset["test"],
)
trainer.train()
5. BertForSequenceClassification
的模型架构
BertForSequenceClassification
主要由 BERT 编码器 + 线性分类层 组成:
[BERT Encoder]
↓
[CLS Token Representation]
↓
[Linear Classifier] → 分类标签
BERT
提供文本的上下文表示CLS token
(第一个 token)作为文本的整体表示Linear Classifier
进行分类
如果你想自定义分类头,可以使用:
from transformers import BertModel
import torch.nn as nn
class CustomBERTClassifier(nn.Module):
def __init__(self, num_labels=2):
super().__init__()
self.bert = BertModel.from_pretrained("bert-base-uncased")
self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)
def forward(self, input_ids, attention_mask):
outputs = self.bert(input_ids, attention_mask=attention_mask)
logits = self.classifier(outputs.pooler_output) # 使用 CLS token
return logits
6. BertForSequenceClassification
适用于哪些任务?
任务 | 适用情况 |
---|---|
情感分析 | 是 |
垃圾邮件检测 | 是 |
新闻分类 | 是 |
文本匹配任务(如 NLI) | 适用,但推荐 BertForNextSentencePrediction |
问答(Question Answering) | 不适用,推荐 BertForQuestionAnswering |
7. BertForSequenceClassification
vs BertModel
模型 | 作用 |
---|---|
BertModel | 仅包含 Transformer 编码器 |
BertForSequenceClassification | BertModel + 线性分类层 |
BertForTokenClassification | BertModel + 序列标注(NER)层 |
BertForQuestionAnswering | BertModel + 问答层 |
如果你的任务是 分类任务,推荐 BertForSequenceClassification
。
8. BertForSequenceClassification
vs RoBERTa/GPT-2
模型 | 任务类型 | 适用情况 |
---|---|---|
BertForSequenceClassification | 文本分类 | 适用于所有文本分类任务 |
RoBERTaForSequenceClassification | 文本分类 | 与 BERT 类似,但去掉了 NSP 任务 |
GPT-2 | 语言生成 | 适用于文本生成任务 |
如果你的任务是 文本分类,推荐 BertForSequenceClassification
,如果是 文本生成,推荐 GPT-2
。
9. 总结
BertForSequenceClassification
是 BERT 的文本分类版本,适用于 情感分析、垃圾邮件检测、新闻分类 等任务。- 在
BertModel
之上增加了线性分类层
,可以直接用于分类任务。 - 使用
Trainer
进行微调,可以轻松在 Hugging Facedatasets
数据集上训练自定义分类模型。 - 适用于二分类和多分类任务,可以通过
num_labels=3
进行三分类等扩展。 - 可以替换默认的分类头,通过自定义
nn.Linear
进行调整。