使用 Bert 做文本分类,利用 Trainer 框架实现 二分类,事半功倍

简介

使用 AutoModelForSequenceClassification 导入Bert 模型。
很多教程都会自定义 损失函数,然后手动实现参数更新。
但本文不想手动微调,故使用 transformers 的 Trainer 自动微调。
人生苦短,我用框架,不仅可保证微调出的模型的效果,而且还省时间。

导包

import evaluate
import numpy as np
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
)

import torch
from torch import nn

import os
os.environ['HTTP_PROXY'] = 'http://127.0.0.1:7890'
os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:7890'

# AG_News 英文分类数据集
# ds = load_dataset("fancyzhx/ag_news")

## 中文分类数据集
ds = load_dataset("lansinuote/ChnSentiCorp")

数据集的详情如下:

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 9600
    })
    validation: Dataset({
        features: ['text', 'label'],
        num_rows: 1200
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 1200
    })
})
ds["train"][0]
{'text': '选择珠江花园的原因就是方便,有电动扶梯直接到达海边,周围餐馆、食廊、商场、超市、摊位一应俱全。酒店装修一般,但还算整洁。 泳池在大堂的屋顶,因此很小,不过女儿倒是喜欢。 包的早餐是西式的,还算丰富。 服务吗,一般',
 'label': 1}

加载 Bert 模型

model_name = "bert-base-chinese"

tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
)

bert = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    trust_remote_code=True,
    num_labels=2,
)

如果你无法联网的话,使用本地huggingface模型:

bert = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    trust_remote_code=True,
    revision="c30a6ed22ab4564dc1e3b2ecbf6e766b0611a33f",
    local_files_only=True,
    num_labels=2,
)

查看 bert 分类模型的网络结构:

bert

在这里插入图片描述

如上图所示,Bert 的分类模型:在原生的 Bert 模型后,加了一个Linear

下述是数据集转换函数:

def tokenize_func(item):
    global tokenizer
    tokenized_inputs = tokenizer(
        item["text"],
        max_length=512,
        truncation=True,
    )
    return tokenized_inputs
tokenized_datasets = ds.map(
    tokenize_func,
    batched=True,
)

tokenized_datasets 的详情如下所示:

DatasetDict({
    train: Dataset({
        features: ['text', 'label', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 9600
    })
    validation: Dataset({
        features: ['text', 'label', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 1200
    })
    test: Dataset({
        features: ['text', 'label', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 1200
    })
})

Train

from transformers import TrainingArguments

args = TrainingArguments(
    "ChnSentiCorp_text_cls",
    eval_steps=8,
    evaluation_strategy="steps",
    save_strategy="epoch",
    save_total_limit=3,
    learning_rate=2e-5,
    num_train_epochs=3,
    weight_decay=0.01,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=16,
    logging_steps=8,
    save_safetensors=True,
    overwrite_output_dir=True,
    # load_best_model_at_end=True,
)

TrainingArguments 的参数解释点击查看下述文章:
LLM大模型之Trainer以及训练参数

from transformers import DataCollatorWithPadding

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
from transformers import Trainer

trainer = Trainer(
    model=bert,
    args=args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    # compute_metrics=compute_metrics,
    tokenizer=tokenizer,
)
trainer.train()

训练过程,在终端可以看见,训练和验证的损失值变化。
在这里插入图片描述

如果安装了 wandb,并且在系统环境变量中,进行了设置。

训练过程和评估过程的记录会自动上传到wandb中。

wandb

若你想使用 wandb,自行进行安装;个人强烈推荐,一劳永逸,这样就无需自己绘图展示模型的训练过程了。

在模型训练的过程,进入 wandb https://wandb.ai/home 看看模型的现在的训练的过程。
在这里插入图片描述

在这里插入图片描述

上图是在 wandb 网站看到的图,横轴是 epoch ,纵轴是 loss。
蓝色折线是在验证集上的损失,橙色折线是在训练集上的损失。

可以很直观的看到,在训练集上的loss 小于 在验证集上的 loss。

predict

训练完成的模型,使用 predict 方法,在测试集上预测。

predictions = trainer.predict(tokenized_datasets["test"])
preds = np.argmax(predictions.predictions, axis=-1)
preds

输出结果:

array([1, 0, 0, ..., 1, 1, 0])

预测结果评估

def eval_data(data):
    predictions = trainer.predict(data)
    preds = np.argmax(predictions.predictions, axis=-1)
    metric = evaluate.load("glue", "mrpc")
    return metric.compute(predictions=preds, references=predictions.label_ids)
eval_data(tokenized_datasets["test"])

输出结果:

{'accuracy': 0.9475, 'f1': 0.9478908188585607}

Bert 工具类

有些时候,在模型应用中,我们想快速使用Bert模型,不想过多关注Bert的训练细节。故提供了下述Bert的工具类。

对于提供给 BertCLS 的数据集,要提前做tokenize化。为了让 BertCLS 更通用化,不在其中内置 tokenize 处理数据集的函数。

from dataclasses import dataclass


@dataclass
class BertCLS:
    def __init__(
        self, model, tokenizer, train_dataset=None, eval_dataset=None, output_dir="output", epoch=3
    ):
        self.model = model
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset

        from transformers import DataCollatorWithPadding
        self.data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

        self.args = self.get_args(output_dir, epoch)

        self.trainer = Trainer(
            model=self.model,
            args=self.args,
            train_dataset=self.train_dataset,
            eval_dataset=self.eval_dataset,
            data_collator=self.data_collator,
            # compute_metrics=compute_metrics,
            tokenizer=tokenizer,
        )

    def get_args(self, output_dir, epoch):
        if self.eval_dataset:
            args = TrainingArguments(
                output_dir=output_dir,
                evaluation_strategy="epoch",
                save_strategy="epoch",
                save_total_limit=3,
                learning_rate=2e-5,
                num_train_epochs=epoch,
                weight_decay=0.01,
                per_device_train_batch_size=32,
                per_device_eval_batch_size=16,
                # logging_steps=16,
                save_safetensors=True,
                overwrite_output_dir=True,
                load_best_model_at_end=True,
            )
        else:
            args = TrainingArguments(
                output_dir=output_dir,
                evaluation_strategy="no",
                save_strategy="epoch",
                save_total_limit=3,
                learning_rate=2e-5,
                num_train_epochs=epoch,
                weight_decay=0.01,
                per_device_train_batch_size=32,
                per_device_eval_batch_size=16,
                # logging_steps=16,
                save_safetensors=True,
                overwrite_output_dir=True,
                # load_best_model_at_end=True,
            )
        return args

    def set_args(self, args):
        """
        从外部重新设置 TrainingArguments,args 更新后,trainer也进行更新
        """
        self.args = args

        self.trainer = Trainer(
            model=self.model,
            args=self.args,
            train_dataset=self.train_dataset,
            eval_dataset=self.eval_dataset,
            data_collator=self.data_collator,
            # compute_metrics=compute_metrics,
            tokenizer=tokenizer,
        )

    def train(self, over_write=False):
        best_model_path = os.path.join(self.args.output_dir, "best_model")

        if over_write:
            self.trainer.train()
            self.trainer.save_model()
        elif not os.path.exists(best_model_path):
            self.trainer.train()
            self.trainer.save_model()
        else:
            print(
                f"预训练权重 {best_model_path} 已存在,且over_write={over_write}。不启动模型训练!"
            )

    def eval(self, eval_dataset):
        predictions = self.trainer.predict(eval_dataset)
        preds = np.argmax(predictions.predictions, axis=-1)
        metric = evaluate.load("glue", "mrpc")
        return metric.compute(predictions=preds, references=predictions.label_ids)

    def pred(self, pred_dataset):
        predictions = self.trainer.predict(pred_dataset)
        preds = np.argmax(predictions.predictions, axis=-1)
        return pred_dataset.add_column("pred", preds)

关于该工具类的使用可以浏览下述文章:https://github.com/JieShenAI/csdn/blob/main/24/09/bert_cls/bert.ipynb

总结

总体上看,本文做了一下数据集的处理,大模型的微调过程、模型权重报错、日志记录,这些过程全部由 transformers 的 Trainer 自动进行。

用好 框架, 事半功倍。当然前提是已经掌握了基础的手动参数微调。

参考资料

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

jieshenai

为了遇见更好的文章

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值