BERT模型微调的基本步骤(demo)

对于预训练模型的微调,一个常见的例子是使用BERT模型进行情感分析任务。以下是一个使用Python和Transformers库进行BERT模型微调的基本步骤:

首先,安装必要的库,包括transformers和torch:

pip install transformers torch

以下是微调代码的例子:

from transformers import BertTokenizer, BertForSequenceClassification
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.optim import AdamW
import torch

# 1. 加载预训练的tokenizer和模型
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased", # 使用12层的BERT模型
    num_labels = 2, # 二分类任务(比如情感分析)
    output_attentions = False, # 模型是否返回注意力权重
    output_hidden_states = False, # 模型是否返回所有隐藏状态
)

# 2. 准备数据
# 假设我们有一些文本数据和对应的标签
texts = ['I love this movie!', 'I hate this movie!']
labels = [1, 0]  # 1代表积极情绪,0代表消极情绪

# 使用tokenizer处理文本数据
inputs = tokenizer(texts, padding=True, truncation=True, max_length=512, return_tensors='pt')

# 把标签转换成Tensor
labels = torch.tensor(labels)

# 3. 创建一个DataLoader
data = list(zip(inputs['input_ids'], inputs['attention_mask'], labels))
dataloader = DataLoader(data, batch_size=2)

# 4. 微调模型
# 设置优化器
optimizer = AdamW(model.parameters(), lr=1e-5)

# 开始训练
model.train()
for epoch in range(3):  # 这里只做3个epoch的训练
    for batch in dataloader:
        input_ids, attention_mask, labels = batch
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

# 5. 保存微调后的模型
model.save_pretrained('./my_model')

在这个例子中,我们首先加载了预训练的BERT模型和对应的tokenizer。然后,我们准备了一些文本数据和对应的标签,使用tokenizer处理文本数据,然后创建了一个DataLoader。接下来,我们设置了优化器,开始训练模型。最后,我们保存了微调后的模型。

请注意,这只是一个非常基础的例子,实际上在进行模型微调时,你可能需要处理更复杂的数据,选择合适的损失函数和优化器,以及进行模型性能的评估等等。你可以查看Transformers库的文档和示例来获取更多信息。

  • 1
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值