#!/usr/bin/env python
# -*- coding: utf-8 -*-
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertForSequenceClassification, AdamW, get_linear_schedule_with_warmup
import pandas as pd
# Example data (replace with your own dataset).
train_texts = ["我喜欢这个电影", "我讨厌这个电影", "我不在乎这个电影"]
# Integer class ids, one per text; presumably 0=positive, 1=negative,
# 2=neutral given the example sentences — verify against your label scheme.
# Must agree with num_labels below.
train_labels = [0, 1, 2]
# Wrap raw texts and labels as a torch Dataset that tokenizes on the fly.
class SentimentDataset(Dataset):
    """Sentiment-classification dataset yielding BERT-ready tensors.

    Each item is a dict containing ``input_ids``, ``attention_mask`` (and any
    other keys the tokenizer emits), each of shape ``(max_length,)``, plus a
    scalar ``label`` tensor.
    """

    def __init__(self, texts, labels, tokenizer, max_length):
        if len(texts) != len(labels):
            raise ValueError("texts and labels must have the same length")
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        inputs = self.tokenizer.encode_plus(
            text,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        )
        # encode_plus(return_tensors="pt") adds a leading batch dim of 1;
        # drop it so the DataLoader's default collation stacks items into
        # clean (batch, max_length) tensors instead of (batch, 1, max_length),
        # removing the need for fragile .squeeze() calls downstream.
        inputs = {key: tensor.squeeze(0) for key, tensor in inputs.items()}
        inputs["label"] = torch.tensor(label)
        return inputs
# --- Hyperparameters ---
num_labels = 3      # three sentiment classes, matching the example labels above
max_length = 128    # max tokens per example; longer texts are truncated
batch_size = 16
epochs = 2

# Tokenizer for Chinese BERT (downloaded from the Hugging Face hub on first use).
tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")

# --- Dataset and loader ---
train_dataset = SentimentDataset(train_texts, train_labels, tokenizer, max_length)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# --- Model, optimizer, LR schedule ---
# Load BERT with a freshly initialized classification head and move it to
# GPU when available.
model = BertForSequenceClassification.from_pretrained("bert-base-chinese", num_labels=num_labels)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# transformers.AdamW is deprecated (removed in recent transformers releases);
# torch.optim.AdamW is the supported replacement. weight_decay=0.0 matches the
# old transformers default (torch's own default is 0.01).
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.0)
# Linear decay to zero over all training steps, no warmup.
total_steps = len(train_dataloader) * epochs
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)
# --- Fine-tuning loop ---
for epoch in range(epochs):
    model.train()
    for batch in train_dataloader:
        optimizer.zero_grad()
        # .squeeze(1) drops the spurious middle dim produced by collating
        # (1, max_length) items into (batch, 1, max_length). Unlike a bare
        # .squeeze(), it can never remove the batch dimension when the final
        # batch happens to contain a single example, and it is a no-op if the
        # tensors are already (batch, max_length).
        input_ids = batch["input_ids"].squeeze(1).to(device)
        attention_mask = batch["attention_mask"].squeeze(1).to(device)
        labels = batch["label"].to(device)
        # Passing labels makes the model compute the cross-entropy loss itself.
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        scheduler.step()  # per-step schedule: advance after every optimizer step
    print("Epoch", epoch + 1, "completed")
# Save the fine-tuned weights (state_dict only — reload later with
# model.load_state_dict(torch.load(...)) on a matching architecture).
torch.save(model.state_dict(), "bert_sentiment_classifier.pt")
# Fine-tuning BERT for sentiment classification.
# First published: 2023-07-16 22:49:23