基于Robert的文本分类任务,在此基础上考虑融合对比学习、Prompt和对抗训练来提升模型的文本分类能力,我本地有SST-2数据集的train.txt、dev.txt两个文件,每个文件包含文本内容和标签两列,是个二分类任务,本项目基于pytorch实现。
先介绍一下要融合的三个技术。
1. 对比学习旨在通过对比相似和不相似的样本来提高分类模型的性能。对于每个样本,我们可以在训练时随机选取一个与其相似的样本,并加入到训练中,以鼓励模型更好地学习相似样本的特征,同时在训练时也要随机选取一个不相似的样本,并将其加入到训练中。这可以帮助模型更好地区分不同类别之间的特征。
2. Prompt是一种基于预设文本片段的模型输入方式。通过给定关键词和语法结构,Prompt可以引导模型学习某些具体任务。在文本分类任务中,我们可以给模型预设一些文本提示,以帮助模型更好地学习关键特征。
3. 对抗训练是一种在训练模型时加入干扰数据(扰动)的技术,以增强模型的鲁棒性。在文本分类任务中,我们可以通过向文本中添加词语或修改词语顺序,来生成干扰数据,从而帮助模型更好地区分和理解输入文本。
目录
一、安装依赖库
下面是具体实现的代码,我们将使用PyTorch框架:
首先安装必要的库:
!pip install transformers
!pip install torch
!pip install scikit-learn
然后我们导入需要的库以及设置随机种子以保证实验可重复性等必要组件:
import random
import numpy as np
import torch
from sklearn.metrics import accuracy_score, f1_score
from transformers import RobertaTokenizer, RobertaForSequenceClassification, AdamW
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from transformers import get_linear_schedule_with_warmup
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
torch.backends.cudnn.deterministic = True
二、载数据集并进行数据预处理
class TextDataset(Dataset):
def __init__(self, tokenizer, path, max_length):
self.tokenizer = tokenizer
self.max_length = max_length
self.labels = []
self.texts = []
with open(path) as f:
for line in f:
line = line.strip().split('\t')
text, label = line[0], int(line[1])
self.labels.append(label)
self.texts.append(text)
def __len__(self):
return len(self.labels)
def __getitem__(self, idx):
text, label = self.texts[idx], self.labels[idx]
encoding = self.tokenizer(text, truncation=True, padding='max_length', max_length=self.max_length,
return_tensors='pt')
return dict(
text=text,
input_ids=encoding['input_ids'].squeeze(),
attention_mask=encoding['attention_mask'].squeeze(),
labels=torch.tensor(label)
)
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
train_dataset = TextDataset(tokenizer, 'train.txt', 256)
dev_dataset = TextDataset(tokenizer, 'dev.txt', 256)
train_sampler = RandomSampler(train_dataset)
dev_sampler = SequentialSampler(dev_dataset)
train_loader = DataLoader(train_dataset, sampler=train_sampler, batch_size=16)
dev_loader = DataLoader(dev_dataset, sampler=dev_sampler, batch_size=16)
三、定义模型并训练模型
model = RobertaForSequenceClassification.from_pretrained('roberta-base', num_labels=2).to(device)
optimizer = AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)
# We will use a linear decay scheduler
total_steps = len(train_loader) * 5
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)
for epoch in range(5):
model.train()
for batch in train_loader:
batch = {k: v.to(device) for k, v in batch.items()}
optimizer.zero_grad()
outputs = model(**batch)
loss = outputs[0]
loss.backward()
optimizer.step()
scheduler.step()
model.eval()
with torch.no_grad():
targets, preds = [], []
for batch in dev_loader:
batch = {k: v.to(device) for k, v in batch.items()}
outputs = model(**batch)
targets.extend(batch['labels'].tolist())
preds.extend(torch.argmax(outputs.logits, axis=-1).tolist())
acc = accuracy_score(targets, preds)
f1 = f1_score(targets, preds)
print(f'\nEpoch {epoch + 1}:')
print(f'Dev Accuracy: {acc:.4f}')
print(f'Dev F1 Score: {f1:.4f}')
至此,我们已经成功地训练了一款基于RoBERTa模型的文本分类器。下面是加入融合技术的实现。
四、对比学习实现
def random_similar_text(texts, labels):
res_texts, res_labels = [], []
for idx, text in enumerate(texts):
res_texts.append(text)
res_labels.append(labels[idx])
# 随机选择一个与当前样本相似的样本,将它加入到数据集中
rand_idx = np.random.choice(len(texts), 1)[0]
res_texts.append(texts[rand_idx])
res_labels.append(labels[rand_idx])
# 随机选择一个不相似的样本,将它加入到数据集中
rand_idx = np.random.choice(len(texts), 1)[0]
while rand_idx == idx:
rand_idx = np.random.choice(len(texts), 1)[0]
res_texts.append(texts[rand_idx])
res_labels.append(labels[rand_idx])
return res_texts, res_labels
train_texts, train_labels = random_similar_text(train_dataset.texts, train_dataset.labels)
train_dataset = TextDataset(tokenizer, 'train.txt', 256)
五、Prompt实现
def add_prompt(prompt, texts):
return [f'{prompt}{text}' for text in texts]
train_dataset.texts = add_prompt('This text is', train_dataset.texts)
dev_dataset.texts = add_prompt('This text is', dev_dataset.texts)
六、对抗训练实现
def add_perturbations(text, n):
# 随机选择n个词,并在其周围添加一些噪声生成n个干扰文本
words = text.split()
idx_list = np.random.choice(len(words), n, replace=False)
for idx in idx_list:
words[idx] = f'[{words[idx]}]'
return ' '.join(words)
def generate_perturbations(texts):
return [add_perturbations(text, 3) for text in texts]
train_dataset.texts += generate_perturbations(train_dataset.texts)
dev_dataset.texts += generate_perturbations(dev_dataset.texts)
七、整个过程封装成一个函数
def train_roberta_with_fusion(train_path, dev_path, num_classes, fusion_type):
def random_similar_text(texts, labels):
res_texts, res_labels = [], []
for idx, text in enumerate(texts):
res_texts.append(text)
res_labels.append(labels[idx])
rand_idx = np.random.choice(len(texts), 1)[0]
res_texts.append(texts[rand_idx])
res_labels.append(labels[rand_idx])
rand_idx = np.random.choice(len(texts), 1)[0]
while rand_idx == idx:
rand_idx = np.random.choice(len(texts), 1)[0]
res_texts.append(texts[rand_idx])
res_labels.append(labels[rand_idx])
return res_texts, res_labels
def add_perturbations(text, n):
words = text.split()
idx_list = np.random.choice(len(words), n, replace=False)
for idx in idx_list:
words[idx] = f'[{words[idx]}]'
return ' '.join(words)
def add_prompt(prompt, texts):
return [f'{prompt}{text}' for text in texts]
def generate_perturbations(texts):
return [add_perturbations(text, 3) for text in texts]
class TextDataset(Dataset):
def __init__(self, tokenizer, path, max_length):
self.tokenizer = tokenizer
self.max_length = max_length
self.labels = []
self.texts = []
with open(path) as f:
for line in f:
line = line.strip().split('\t')
text, label = line[0], int(line[1])
self.labels.append(label)
self.texts.append(text)
if fusion_type == 'contrastive':
self.texts, self.labels = random_similar_text(self.texts, self.labels)
if fusion_type == 'adversarial':
self.texts += generate_perturbations(self.texts)
if fusion_type == 'prompt':
self.texts = add_prompt('This text is', self.texts)
def __len__(self):
return len(self.labels)
def __getitem__(self, idx):
text, label = self.texts[idx], self.labels[idx]
encoding = self.tokenizer(text, truncation=True, padding='max_length', max_length=self.max_length,
return_tensors='pt')
return dict(
text=text,
input_ids=encoding['input_ids'].squeeze(),
attention_mask=encoding['attention_mask'].squeeze(),
labels=torch.tensor(label)
)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
train_dataset = TextDataset(tokenizer, train_path, 256)
dev_dataset = TextDataset(tokenizer, dev_path, 256)
train_sampler = RandomSampler(train_dataset)
dev_sampler = SequentialSampler(dev_dataset)
train_loader = DataLoader(train_dataset, sampler=train