Hugging Face notes: fine-tuning a model

1 Data processing

1.1 Loading the dataset

from datasets import load_dataset

dataset = load_dataset("yelp_review_full")

dataset
'''

DatasetDict({
    train: Dataset({
        features: ['label', 'text'],
        num_rows: 650000
    })
    test: Dataset({
        features: ['label', 'text'],
        num_rows: 50000
    })
})
'''
dataset['train'][0]
'''
{'label': 4,
 'text': "dr. goldberg offers everything i look for in a general practitioner.  he's nice and easy to talk to without being patronizing; he's always on time in seeing his patients; he's affiliated with a top-notch hospital (nyu) which my parents have explained to me is very important in case something happens and you need surgery; and you can get referrals to see specialists without having to see him first.  really, what more do you need?  i'm sitting here trying to think of any complaints i have about him, but i'm really drawing a blank."}
'''

1.2 Tokenization

  • A tokenizer is needed to process the text, including a padding and truncation strategy to handle variable sequence lengths.
  • Use the Datasets map method to apply a preprocessing function over the entire dataset (batched=True feeds the function batches of examples, which speeds things up):
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-cased")


def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)


tokenized_datasets = dataset.map(tokenize_function, batched=True)

tokenized_datasets
'''
DatasetDict({
    train: Dataset({
        features: ['label', 'text', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 650000
    })
    test: Dataset({
        features: ['label', 'text', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 50000
    })
})
'''
tokenized_datasets['train'][0]
'''
{'label': 4,
 'text': "dr. goldberg offers everything i look for in a general practitioner.  he's nice and easy to talk to without being patronizing; he's always on time in seeing his patients; he's affiliated with a top-notch hospital (nyu) which my parents have explained to me is very important in case something happens and you need surgery; and you can get referrals to see specialists without having to see him first.  really, what more do you need?  i'm sitting here trying to think of any complaints i have about him, but i'm really drawing a blank.",
 'input_ids': [101, 173, 1197, 119, 2284, 2953, 3272, 1917, 178, 1440, 1111,
  1107, 170, 1704, 22351, 119, 1119, 112, 188, 3505, 1105, 3123, 1106, 2037,
  1106, 1443, 1217, 10063, 4404, 132, 1119, 112, 188, 1579, 1113, 1159, 1107,
  3195, 1117, 4420, 132, 1119, 112, 188, 6559, 1114, 170, 1499, 118, 23555,
  2704, 113, 183, 9379, 114, 1134, 1139, 2153, 1138, 3716, 1106, 1143, 1110,
  1304, 1696, 1107, 1692, 1380, 5940, 1105, 1128, 1444, 6059, 132, 1105, 1128,
  1169, 1243, 5991, 16179, 1106, 1267, 18137, 1443, 1515, 1106, 1267, 1140,
  1148, 119, 1541, 117, 1184, 1167, 1202, 1128, 1444, 136, 178, 112, 182,
  2807, 1303, 1774, 1106, 1341, 1104, 1251, 11344, 178, 1138, 1164, 1140, 117,
  1133, 178, 112, 182, 1541, 4619, 170, 9153, 119, 102, 0, 0, 0, 0, 0, 0, 0,
  0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
'''
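
The trailing zeros are [PAD] tokens: with no explicit max_length, padding="max_length" pads every example up to the tokenizer's model_max_length, which is 512 for bert-base-cased. A quick check, using the objects created above:

print(tokenizer.model_max_length)                        # 512
print(len(tokenized_datasets["train"][0]["input_ids"]))  # 512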

1.3 Creating a smaller subset

Create a smaller subset of the dataset for fine-tuning, to cut down the time required:

small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))

small_train_dataset
'''

Dataset({
    features: ['label', 'text', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 1000
})
'''

2 Fine-tuning with the Trainer class provided by Transformers

  • Transformers provides a Trainer class optimized for training Transformers models

2.1 Loading the model

  • First load the model and specify the number of expected labels (the Yelp reviews have five star ratings, hence num_labels=5)
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(
    "google-bert/bert-base-cased", num_labels=5
)

2.2 The TrainingArguments class

  • Create a TrainingArguments object
    • It contains all the hyperparameters you can tune, as well as flags that activate different training options
    • Specify where to save the training checkpoints
from transformers import TrainingArguments

training_args = TrainingArguments(output_dir="test_trainer")

2.2.1 Monitoring evaluation metrics during fine-tuning

To monitor evaluation metrics during fine-tuning, specify the evaluation_strategy parameter in the training arguments so that the metric is reported at the end of each epoch (note: recent transformers releases rename this parameter to eval_strategy):

from transformers import TrainingArguments

training_args = TrainingArguments(output_dir="test_trainer", 
                                  evaluation_strategy="epoch")

2.3 Evaluation

  • Trainer does not automatically evaluate model performance during training
  • You need to pass Trainer a function that computes and reports metrics
  • The Evaluate library provides a simple accuracy function, which can be loaded with evaluate.load:
import numpy as np
import evaluate

metric = evaluate.load("accuracy")

Then call the metric's compute method to calculate the accuracy of the predictions.

Before passing the predictions to compute, the logits have to be converted into class predictions (all Transformers models return logits):

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)
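
As a quick sanity check, compute_metrics can be called directly on toy data (a minimal sketch; the logits and labels below are made up for illustration):

toy_logits = np.array([[0.1, 0.2, 0.9, 0.0, 0.0],   # predicted class 2
                       [0.8, 0.1, 0.0, 0.0, 0.1],   # predicted class 0
                       [0.0, 0.0, 0.1, 0.2, 0.7]])  # predicted class 4
toy_labels = np.array([2, 0, 3])

compute_metrics((toy_logits, toy_labels))  # {'accuracy': 0.666...}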

2.4 Creating the Trainer object

Create a Trainer object with the model, the training arguments, the training and test datasets, and the evaluation function:

from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=small_train_dataset,
    eval_dataset=small_eval_dataset,
    compute_metrics=compute_metrics,
)

(Note: no optimizer has to be defined here; Trainer creates a default AdamW optimizer internally.)

2.5 Starting fine-tuning

trainer.train()
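
Once training finishes, a full evaluation pass over eval_dataset can be run with trainer.evaluate(), which returns the evaluation loss together with the metrics from compute_metrics:

metrics = trainer.evaluate()
print(metrics)  # includes eval_loss and eval_accuracy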

3 Fine-tuning in native PyTorch

First the datasets have to be wrapped in the DataLoader that PyTorch expects. Before that, drop the raw text column, rename label to labels (the argument name the model expects), and switch the format to PyTorch tensors:

from torch.utils.data import DataLoader

# Drop the raw text, rename label -> labels, and return PyTorch tensors.
small_train_dataset = small_train_dataset.remove_columns(["text"]).rename_column("label", "labels").with_format("torch")
small_eval_dataset = small_eval_dataset.remove_columns(["text"]).rename_column("label", "labels").with_format("torch")

train_dataloader = DataLoader(small_train_dataset, shuffle=True, batch_size=8)
eval_dataloader = DataLoader(small_eval_dataset, batch_size=8)

The model is loaded exactly as before.

Declare an optimizer:

from torch.optim import AdamW

optimizer = AdamW(model.parameters(), lr=5e-5)
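
The training loop below also references num_epochs, lr_scheduler, and device, which still need to be defined. A minimal setup using the get_scheduler helper from transformers:

import torch
from transformers import get_scheduler

num_epochs = 3
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    name="linear", optimizer=optimizer,
    num_warmup_steps=0, num_training_steps=num_training_steps,
)

# Move the model to a GPU if one is available.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)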

Then comes the classic PyTorch training loop: forward pass, backward pass, optimizer step.

model.train()
for epoch in range(num_epochs):
    for batch in train_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
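
The matching evaluation loop accumulates batch predictions with the metric's add_batch method and computes the score at the end (a minimal sketch, reusing the accuracy metric and eval_dataloader from above):

metric = evaluate.load("accuracy")

model.eval()
for batch in eval_dataloader:
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = model(**batch)

    # Convert logits to class predictions and accumulate batch by batch.
    predictions = torch.argmax(outputs.logits, dim=-1)
    metric.add_batch(predictions=predictions, references=batch["labels"])

metric.compute()  # {'accuracy': ...}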
