前言:
总结代码:使用 PyTorch 训练和加载 BERT 模型的完整示例。
训练模型的代码:
import torch
from transformers import BertTokenizer, BertForSequenceClassification
from torch.utils.data import TensorDataset, DataLoader
# 加载数据集
train_texts = [...] # 训练集文本
train_labels = [...] # 训练集标签
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
train_input_ids = []
train_attention_masks = []
for text in train_texts:
encoded = tokenizer.encode_plus(
text,
add_special_tokens=True,
max_length=128,
padding='max_length',
truncation=True,
return_attention_mask=True,
return_tensors='pt'
)
train_input_ids.append(encoded['input_ids'])
train_attention_masks.append(encoded['attention_mask'])
train_input_ids = torch.cat(train_input_ids, dim=0)
train_attention_masks = torch.cat(train_attention_masks, dim=0)
train_labels = torch.tensor(train_labels)
train_dataset = TensorDataset(train_input_ids, train_attention_masks, train_labels)
train_dataloader = DataLoader(train_dataset, batch_size=32)
# 定义模型和优化器
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
# 训练模型
model.train()
for epoch in range(5):
total_loss = 0
for step, batch in enumerate(train_dataloader):
batch_input_ids = batch[0]
batch_attention_masks = batch[1]
batch_labels = batch[2]
optimizer.zero_grad()
outputs = model(batch_input_ids, attention_mask=batch_attention_masks, labels=batch_labels)
loss = outputs.loss
total_loss += loss.item()
loss.backward()
optimizer.step()
avg_loss = total_loss / len(train_dataloader)
print(f'Epoch {epoch + 1}, average loss: {avg_loss:.4f}')
# 保存模型
torch.save(model.state_dict(), 'bert_model.pth')
这段代码中,首先使用 BertTokenizer 对训练集文本进行编码,生成训练集的输入 ID 和注意力掩码。然后将输入 ID、注意力掩码和标签封装成一个 TensorDataset,并使用 DataLoader 生成训练数据的迭代器。
接下来,使用 BertForSequenceClassification 定义一个 BERT 分类器,并使用 AdamW 优化器进行训练。在每个 epoch 中,遍历所有训练数据,计算模型的损失并更新参数。
最后,使用 torch.save 保存训练好的模型参数到文件中。
加载模型的代码:
import torch
from transformers import BertForSequenceClassification, BertTokenizer
# 加载模型和词汇表
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# 加载模型参数
model.load_state_dict(torch.load('bert_model.pth'))
# 使用模型进行推理
text = 'this is a test sentence'
encoded = tokenizer.encode_plus(
text,
add_special_tokens=True,
max_length=128,
padding
欢迎点赞,收藏!!