自然语言处理NLP:使用DisBert模型完成文本二分类任务(Pytorch)

项目来源于kaggle竞赛,地址为:Natural Language Processing with Disaster Tweets | Kaggle

        本文主要是对本人学习NLP的过程做个总结和记录,以方便日后复习,当然如果本文能帮助到阅读该文的读者,我会感到很开心。

        该项目通过建立一个机器学习模型,预测哪些推文是关于真实灾难的,哪些不是。

        首先读取数据,看一下数据的样子,数据在kaggle上可直接下载。可以看到有用的是text列和target列,我们只需要这两列进行模型训练即可,target为1表示text是真实灾难相关的文本,反之则不是。

import pandas as pd

train_df = pd.read_csv('/kaggle/input/nlpdataset/train.csv')
test_df = pd.read_csv('/kaggle/input/nlpdataset/test.csv')
train_df

         然后定义缩写字符变扩写字符的字典,以便将文本中的缩写字符扩展为全称,在NLP预处理数据过程中,将英文缩写扩展为全称的作用是为了将缩写与其全称等效地表示,以便于算法处理。例如,将"I'm" 扩展为"I am",将"don't" 扩展为"do not",将"it's" 扩展为"it is" 或 "it has"。

        缩写扩展的作用主要是为了消除歧义:同一个缩写可能有多种可能的解释,例如"AI"可以表示"Artificial Intelligence"或者"Air India"。将缩写扩展为全称可以消除这种歧义,确保模型可以正确地理解数据。

        日后如果有文本缩写扩展处理的需要,可直接复制该字典,然后进行替换操作。

# 缩写变扩写
contractions = {
    "ain't": "am not",
    "aren't": "are not",
    "can't": "cannot",
    "can't've": "cannot have",
    "'cause": "because",
    "could've": "could have",
    "couldn't": "could not",
    "couldn't've": "could not have",
    "didn't": "did not",
    "doesn't": "does not",
    "don't": "do not",
    "hadn't": "had not",
    "hadn't've": "had not have",
    "hasn't": "has not",
    "haven't": "have not",
    "he'd": "he would",
    "he'd've": "he would have",
    "he'll": "he will",
    "he'll've": "he will have",
    "he's": "he is",
    "how'd": "how did",
    "how'd'y": "how do you",
    "how'll": "how will",
    "how's": "how is",
    "i'd": "I would",
    "i'd've": "I would have",
    "i'll": "I will",
    "i'll've": "I will have",
    "i'm": "I am",
    "i've": "I have",
    "isn't": "is not",
    "it'd": "it would",
    "it'd've": "it would have",
    "it'll": "it will",
    "it'll've": "it will have",
    "it's": "it is",
    "let's": "let us",
    "ma'am": "madam",
    "mayn't": "may not",
    "might've": "might have",
    "mightn't": "might not",
    "mightn't've": "might not have",
    "must've": "must have",
    "mustn't": "must not",
    "mustn't've": "must not have",
    "needn't": "need not",
    "needn't've": "need not have",
    "o'clock": "of the clock",
    "oughtn't": "ought not",
    "oughtn't've": "ought not have",
    "shan't": "shall not",
    "sha'n't": "shall not",
    "shan't've": "shall not have",
    "she'd": "she would",
    "she'd've": "she would have",
    "she'll": "she will",
    "she'll've": "she will have",
    "she's": "she is",
    "should've": "should have",
    "shouldn't": "should not",
    "shouldn't've": "should not have",
    "so've": "so have",
    "so's": "so is",
    "that'd": "that would",
    "that'd've": "that would have",
    "that's": "that is",
    "there'd": "there would",
    "there's": "there is",
    "they'd": "they would",
    "they'd've": "they would have",
    "they'll": "they will",
    "they'll've": "they will have",
    "they're": "they are",
    "they've": "they have",
    "to've": "to have",
    "wasn't": "was not",
    "we'd": "we would",
    "we'd've": "we would have",
    "we'll": "we will",
    "we'll've": "we will have",
    "we're": "we are",
    "we've": "we have",
    "weren't": "were not",
    "what'll": "what will",
    "what'll've": "what will have",
    "what're": "what are",
    "what's": "what is",
    "what've": "what have",
    "when's": "when is",
    "when've": "when have",
    "where'd": "where did",
    "where's": "where is",
    "where've": "where have",
    "who'll": "who will",
    "who'll've": "who will have",
    "who's": "who is",
    "who've": "who have",
    "why's": "why is",
    "why've": "why have",
    "will've": "will have",
    "won't": "will not",
    "won't've": "will not have",
    "would've": "would have",
    "wouldn't": "would not",
    "wouldn't've": "would not have",
    "y'all": "you all",
    "y'all'd": "you all would",
    "y'all'd've": "you all would have",
    "y'all're": "you all are",
    "y'all've": "you all have",
    "you'd": "you would",
    "you'd've": "you would have",
    "you'll": "you will",
    "you'll've": "you will have",
    "you're": "you are",
    "you've": "you have"
}

country_contractions = {
    'u.s': 'united states',
    'u.s.': 'united states',
    'u.s.a': 'united states',
    'u.k': 'united kingdom',
    'u.k.': 'united kingdom',
    'u.a.e': 'united arab emirates',
    'u.a.e.': 'united arab emirates',
    's.korea': 'south korea',
    'n.korea': 'north korea',
    'czech rep.': 'czech republic',
    'dominican rep.': 'dominican republic',
    'costa rica': 'republic of costa rica',
    'el salvador': 'republic of el salvador',
    'guinea-bissau': 'republic of guinea-bissau',
    'cote d\'ivoire': 'republic of cote d\'ivoire',
    'trinidad & tobago': 'republic of trinidad and tobago',
    'congo-brazzaville': 'republic of the congo',
    'congo-kinshasa': 'democratic republic of the congo',
    'sri lanka': 'democratic socialist republic of sri lanka',
    'central african rep.': 'central african republic',
    'san marino': 'republic of san marino',
    'são tomé & príncipe': 'democratic republic of são tomé and príncipe',
    'timor-leste': 'democratic republic of timor-leste'
}

        定义缩写扩展的函数,需要注意的地方在注释中已经写出。

# 缩写变扩写
def expand_contractions(text,contractions):
    words = text.split()
    expand_contractions = []
    for word in words:
        if word.lower() in contractions:#判断word是否在contractions字典的键中出现
            expand_contraction = contractions[word.lower()]
            expand_contractions.append(expand_contraction)
        else:
            expand_contractions.append(word)
    return ' '.join(expand_contractions)# join函数用于将列表中的字符连接为一个字符串(以空格分割)        

        进行文本清洗,包括缩写字符扩展,以及将URL、数字和标点替换为空格。URL、标点、数字等在自然语言文本中属于噪声,将它们替换为空格可以去除这些噪声,以便模型可以更好的学习到有用的特征。

import re

# 应用expand_contractions并用sub函数进行url、空格、标点、数字字符替换为空格
def clean_text(text):
    # 替换一般缩略字符和城市缩略字符
    text = expand_contractions(text, contractions)
    text = expand_contractions(text, country_contractions)
    
    text = re.sub(r'http\S+|https\S+|www\S+', '', text, flags = re.MULTILINE)
    # \S代表匹配任意非空字符
    # 其中 \S 表示匹配非空白字符,+ 表示匹配一个或多个。
    # \S+ 表示匹配一个或多个非空白字符。| 表示或的关系。
    # \S+ 和 | 组合在一起,表示匹配以 http、www 或 https 开头的字符串。
    # flags=re.MULTILINE这个参数是多余的,可以不加
    
    text = re.sub(r'\W', ' ', text)# 非字母数字字符 空格,标点符号等
    text = re.sub(r'\d', ' ', text)# 数字字符
    text = re.sub(r'\s+', ' ', text).strip()# 匹配一个或多个空格字符,替换为空格,strip函数删除开头和结尾的空格
    
    return text
    
train_df['clean_text'] = train_df['text'].apply(clean_text)
test_df['clean_text'] = test_df['text'].apply(clean_text)
train_df

         分割训练集和验证集,注意原始数据集中本就存在测试集。

import math

split_count = math.floor(len(train_df) / 10 * 9)
val_df = train_df[split_count:]
train_df = train_df[:split_count]

        之后通过Transfoemers库定义DistilBert模型的tokenizer和model,这里使用的是DistilBert模型,它是BERT模型的一种轻量级版本,通过压缩BERT模型的大小和计算量以及进行知识蒸馏,使得DistilBERT在保持高性能的同时,具有更小的模型体积和更快的推理速度。

  1. 知识蒸馏:是一种模型压缩的技术,用于将一个较大和复杂的模型的知识传递给一个较小和 单的模型中。具体来说,知识蒸馏将一个“教师”模型(通常为较大、较复杂的模型)的知识“蒸馏”到一个“学生”模型(通常为较小、较简单的模型)中,使得学生模型能够学习到教师模型中的关键特征和知识,从而获得更好的泛化能力和性能。

        Transformers是由Hugging Face公司开发的一个自然语言处理(NLP)库,它提供了一系列基于深度学习的预训练模型,包括BERT、GPT、RoBERTa、DistilBERT等,方便用户进行模型微调以应用到不同的任务中,Transformers库的应用场景广泛,可以用于各种NLP任务,如情感分析、机器翻译、文本分类、问答系统等。

  1. 预训练模型:在NLP中,预训练模型通常使用大量文本数据进行无监督式预训练,以学习语言的特征和结构,例如语法、语义和上下文等。
  2. 微调:微调是指在已经预训练的模型基础上,进一步训练模型以适应特定任务的过程。微调通常使用有监督学习方法,使用任务特定的数据集进行训练,以调整预训练模型的权重和参数,使其适应特定的NLP任务。

        如果想更进一步了解和学习Transformers库,这里给出pytorch官网关于Transformers库的链接:PyTorch-Transformers | PyTorch

from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import torch

MODEL_NAME = 'distilbert-base-uncased'
tokenizer = DistilBertTokenizer.from_pretrained(MODEL_NAME)
model = DistilBertForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)

        之后将数据集进行格式转换,最终转换为模型输入所需要的torch格式。tokenizer函数的作用是将输入的文本数据进行分词并进行必要的处理,padding和truncation参数用于控制padding和截断的行为。padding参数用于将所有输入的文本数据padding到相同的长度,以便进行批量处理。如果一个样本的长度小于指定的length ,则用pad_token_id进行填充,如果长度大于length ,则根据truncation参数的设置进行截断。

from transformers import TrainingArguments, Trainer

def tokenize(batch):
    return tokenizer(batch['text'], padding=True, truncation=True)

test_df["target"] = 0

#取出文本和标签
train_dataset = train_df[['clean_text', 'target']]
val_dataset = val_df[['clean_text', 'target']]
test_dataset = test_df[['clean_text', 'target']]

#更改列名
train_dataset = train_dataset.rename(columns={'clean_text': 'text', 'target': 'label'})
val_dataset = val_dataset.rename(columns={'clean_text': 'text', 'target': 'label'})
test_dataset = test_dataset.rename(columns={'clean_text': 'text', 'target': 'label'})

#records参数代表将每一行转换为一个字典,并存储在一个列表中
train_dataset = train_dataset.to_dict('records')
val_dataset = val_dataset.to_dict('records')
test_dataset = test_dataset.to_dict('records')

from datasets import Dataset

#其实可以直接将train_dataset转换为Dataset,不需要进行字典转换,多此一举了
train_dataset = Dataset.from_pandas(pd.DataFrame(train_dataset))
val_dataset = Dataset.from_pandas(pd.DataFrame(val_dataset))
test_dataset = Dataset.from_pandas(pd.DataFrame(test_dataset))

# 将batch的大小设置为整个数据集的大小,意味着将整个数据集进行tokenize操作
train_dataset = train_dataset.map(tokenize, batched=True, batch_size=len(train_dataset))
val_dataset = val_dataset.map(tokenize, batched=True, batch_size=len(val_dataset))
test_dataset = test_dataset.map(tokenize, batched=True, batch_size=len(test_dataset))

# 将数据集转换为torch格式
train_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])
val_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])
test_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])

         设置训练参数和训练器,训练并保存模型。

import os 

os.environ["WANDB_DISABLED"] = "true"

# 设置训练参数
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=10,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10,
    load_best_model_at_end=True,
    evaluation_strategy="epoch",
    save_strategy="epoch"
)

# 初始化训练器
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
)

# 训练模型
trainer.train()

# 保存模型参数
model.save_pretrained("./model")

         加载模型并在测试集上进行预测。

​# 加载模型
loaded_model = DistilBertForSequenceClassification.from_pretrained("./model")
​
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def predict(text, loaded_model, tokenizer):
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    inputs.to(device)
    outputs = loaded_model(**inputs)
    logits = outputs.logits
    probabilities = torch.softmax(logits, dim=-1)
    predictions = torch.argmax(probabilities, dim=-1)
    return predictions.item()

test_df["predicted_target"] = test_df["clean_text"].apply(lambda x: predict(x, model, tokenizer))

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值