基于BERT模型的文本分类示例

7 篇文章 0 订阅
5 篇文章 0 订阅

1. 导入必要的库

import torch
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertForSequenceClassification, AdamW
from transformers import get_linear_schedule_with_warmup
import os

2. 设置设备

根据GPU是否可用,设置默认的设备。

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

3. 定义数据集类

创建一个继承自Dataset的类,用于封装文本数据的处理。

class SentenceDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_len=512):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text_v = self.texts[idx]
        label = self.labels[idx]
        encoding = self.tokenizer.encode_plus(
            text_v,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
            truncation=True
        )
        return {
            'text': text_v,
            'input_ids': encoding['input_ids'].squeeze(),
            'attention_mask': encoding['attention_mask'].squeeze(),
            'labels': torch.tensor(label, dtype=torch.long)
        }

4. 读取数据文件的函数

定义一个函数来读取文本数据文件,并将其解析为文本和标签列表。

def read_dataset(file_path):
    texts, labels = [], []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            text, label = line.strip().split('\t')  # 假设数据是制表符分隔
            texts.append(text)
            labels.append(int(label))  # 将标签转换为整数
    return texts, labels

5. 准备数据文件路径

设置训练数据和验证数据的文件路径。

train_file_path = './datasets/train_sentence_dataset.txt'
val_file_path = './datasets/test_sentence_dataset.txt'

6. 读取训练数据和验证数据

使用read_dataset函数读取训练集和验证集的数据。

train_texts, train_labels = read_dataset(train_file_path)
val_texts, val_labels = read_dataset(val_file_path)

7. 加载BERT分词器和模型

加载预训练的BERT分词器和模型。

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

8. 创建数据集实例

使用文本数据和标签创建数据集的实例。

train_dataset = SentenceDataset(train_texts, train_labels, tokenizer)
val_dataset = SentenceDataset(val_texts, val_labels, tokenizer)

9. 准备数据加载器

创建训练和验证的数据加载器。

train_dataloader = DataLoader(train_dataset, batch_size=16)
val_dataloader = DataLoader(val_dataset, batch_size=16)

10. 定义优化器、调度器和损失函数

设置优化器、调度器和损失函数。

optimizer = AdamW(model.parameters(), lr=2e-5)
epochs_num = 100  # 训练的epoch数量
num_training_steps = len(train_dataloader) * epochs_num
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)
loss_fn = torch.nn.CrossEntropyLoss()

11. 检查输出目录是否存在

检查模型保存目录是否存在,如果不存在则创建。

output_dir = './saved_model/'
if not os.path.exists(output_dir):
    os.makedirs(output_dir)
    print(f"目录 {output_dir} 已创建。")
else:
    print(f"目录 {output_dir} 已存在。")

12. 初始化最佳损失和最佳准确率

初始化最佳损失和最佳准确率。

best_loss = float('inf')
best_accuracy = 0.0

13. 训练模型

训练模型的代码,包括前向传播、损失计算、反向传播和参数更新。

model.train()
for epoch in range(epochs_num):
    epoch_loss = 0
    for batch in train_dataloader:
        batch = {k: v.to(device) for k, v in batch.items() if k != 'text'}
        
        outputs = model(**batch)
        loss = loss_fn(outputs.logits, batch['labels'])
        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        epoch_loss += loss.item()

    # 进行验证
    model.eval()
    epoch_val_loss = 0
    correct_predictions = 0
    total_predictions = 0
    with torch.no_grad():
        for batch in val_dataloader:
            batch = {k: v.to(device) for k, v in batch.items()if k != 'text'}
            outputs = model(**batch)
            val_loss = loss_fn(outputs.logits, batch['labels'])
            epoch_val_loss += val_loss.item()
            
            _, predicted_labels = torch.max(outputs.logits, 1)
            total_predictions += batch['labels'].size(0)
            correct_predictions += (predicted_labels == batch['labels']).sum().item()

    epoch_val_loss /= len(val_dataloader)
    val_accuracy = correct_predictions / total_predictions

    print(f"Epoch {epoch + 1}/{epochs_num}, Train Loss: {epoch_loss / len(train_dataloader)}, "
          f"Validation Loss: {epoch_val_loss}, Validation Accuracy: {val_accuracy}")

    # 检查是否是最佳模型
    if val_accuracy > best_accuracy:
        best_accuracy = val_accuracy
        best_model_path = os.path.join(output_dir, 'best_text2cls.pth')
        torch.save(model.state_dict(), best_model_path)
        print(f"找到最佳模型,验证准确率为 {best_accuracy:.4f},在Epoch {epoch + 1},模型已保存。")

14. 保存最终模型

训练完成后,保存最终模型。

final_model_path = os.path.join(output_dir, 'final_model.pth')
torch.save(model.state_dict(), final_model_path)
print("Training complete. Final model has been saved to", final_model_path)

完整代码

详细代码请参考: GitHub链接
欢迎Star和Fork项目

备注

水平有限,有问题随时交流~


  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值