NLP Few-Shot Learning: 5-Shot Classification on AG_NEWS

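Transfer learning is an effective approach to 5-shot classification in NLP, and it is particularly well suited to text classification when labeled data is scarce. In this setting only a handful of labeled examples (5 per class) are available for training, and transfer learning lets us reuse knowledge learned from other, related data to improve model performance.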
I. Overview

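  1. Transfer learning starts from a model pretrained on large-scale data (such as BERT or GPT) and fine-tunes it on a small amount of labeled data to learn the new task. The general-purpose features the base model learned from the large corpus help it understand and generalize to the target domain.
  2. For 5-shot classification, one option is to freeze most of the base model's parameters and fine-tune only the last few layers. With far fewer trainable parameters, the model tends to converge faster and can still perform reasonably well on very little data. Meta-learning is another option: the model is trained across many small tasks to improve its ability to generalize to new ones.
  3. Cross-lingual transfer learning is another common strategy, especially for multilingual NLP tasks: a model trained on one language is applied to a task in another language, which can also help in 5-shot classification.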
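
The layer-freezing idea in point 2 can be sketched with the Hugging Face transformers library. This is a minimal sketch under stated assumptions, not the approach used by the script later in this post: the helper name `freeze_all_but_last_layers`, the tiny `BertConfig` sizes, and the choice of two trainable layers are all illustrative. The model is randomly initialized from a config so the snippet runs without downloading weights; in practice one would load a pretrained checkpoint with `from_pretrained` instead.

```python
from transformers import BertConfig, BertForSequenceClassification


def freeze_all_but_last_layers(model, num_trainable_layers=2):
    """Freeze the BERT encoder except its last few layers (hypothetical helper)."""
    # Start by freezing every parameter of the BERT backbone.
    for param in model.bert.parameters():
        param.requires_grad = False
    # Re-enable gradients for the last `num_trainable_layers` encoder layers.
    for layer in model.bert.encoder.layer[-num_trainable_layers:]:
        for param in layer.parameters():
            param.requires_grad = True
    # The pooler feeds the classification head, so keep it trainable too.
    for param in model.bert.pooler.parameters():
        param.requires_grad = True
    # model.classifier is outside model.bert and was never frozen.
    return model


# Assumption: a tiny, randomly initialized config purely for illustration.
config = BertConfig(
    hidden_size=32,
    num_hidden_layers=4,
    num_attention_heads=2,
    intermediate_size=64,
    num_labels=4,  # AG_NEWS has 4 classes
)
model = BertForSequenceClassification(config)
freeze_all_but_last_layers(model, num_trainable_layers=2)

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable parameters: {trainable} / {total}")
```

Only parameters with `requires_grad=True` receive gradient updates, so an optimizer built from `model.parameters()` will leave the frozen layers untouched.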

II. Project Introduction

2.1 Dataset

AG_NEWS: a news corpus covering 4 top-level categories: World, Sports, Business, and Sci/Tech.
AG_NEWS contains 120,000 training samples (train.csv) and 7,600 test samples (test.csv). Each category has 30,000 training samples and 1,900 test samples.

2.2 BERT (Bidirectional Encoder Representations from Transformers)

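BERT is a pretrained language model proposed by Google in 2018. Unlike traditional language models, which predict the next word in one direction only (left to right or right to left), BERT is built on the Transformer encoder and uses context from both directions during pretraining. Pretraining combines two objectives: Masked Language Model (MLM) and Next Sentence Prediction (NSP). For MLM, BERT randomly masks some of the tokens in the input (about 15%) and predicts them from the surrounding context. For NSP, BERT takes a pair of sentences and predicts whether the second one follows the first in the original text. After pretraining, BERT can be fine-tuned on downstream tasks such as text classification, named entity recognition, and natural language inference, which carries its learned language representations over to a wide range of NLP tasks. BERT's advantages include: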

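  • Pretraining uses bidirectional context, which helps the model capture the meaning of a word from its full surrounding sentence.
  • BERT can be applied to many downstream tasks with fine-tuning alone; there is no need to train a new model from scratch for each task.
  • BERT achieved strong results on many NLP tasks, surpassing earlier models on a number of benchmarks.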
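
The MLM objective described above can be illustrated with a small masking routine. This is a simplified sketch, not BERT's actual preprocessing code: the function name `mask_tokens`, the toy vocabulary, and the per-token coin flip are assumptions. The 80% / 10% / 10% split (replace with `[MASK]`, replace with a random token, keep unchanged) follows the recipe described in the BERT paper.

```python
import random

MASK_TOKEN = "[MASK]"


def mask_tokens(tokens, vocab, mask_prob=0.15, seed=None):
    """Return (masked_tokens, labels) for a masked language modeling example.

    labels[i] holds the original token at positions the model must predict,
    and None everywhere else.
    """
    rng = random.Random(seed)
    masked = list(tokens)
    labels = [None] * len(tokens)
    for i, token in enumerate(tokens):
        # Select each position independently with probability mask_prob.
        if rng.random() >= mask_prob:
            continue
        labels[i] = token  # the model must recover the original token here
        roll = rng.random()
        if roll < 0.8:
            masked[i] = MASK_TOKEN          # 80%: replace with [MASK]
        elif roll < 0.9:
            masked[i] = rng.choice(vocab)   # 10%: replace with a random token
        # remaining 10%: leave the token unchanged
    return masked, labels


# Toy inputs for illustration only.
tokens = ["stocks", "rose", "sharply", "after", "the", "earnings", "report"]
vocab = ["market", "team", "game", "software", "oil", "election"]
masked, labels = mask_tokens(tokens, vocab, mask_prob=0.15, seed=0)
print(masked)
print(labels)
```

During pretraining the loss is computed only at positions where `labels` is not `None`, so the model has to use both left and right context to fill in the selected tokens.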

III. Code Implementation

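5-shot classification, as the name suggests, trains on only 5 samples per class. The script below draws those samples from train.csv and evaluates on samples drawn from test.csv (5 per class here; evaluating on the full test set gives a more reliable estimate).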
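
The per-class sampling step can be sketched on its own with the standard library. This is an illustrative sketch on a toy corpus: the helper name `sample_k_shot` and the placeholder texts are assumptions, not part of the original project. It samples without replacement so that the 5 examples for each class are distinct, and takes a seed so the same support set can be drawn again.

```python
import random
from collections import defaultdict


def sample_k_shot(examples, k, seed=0):
    """Draw k distinct examples per label from a list of (text, label) pairs."""
    by_label = defaultdict(list)
    for text, label in examples:
        by_label[label].append((text, label))

    rng = random.Random(seed)
    support = []
    for label in sorted(by_label):
        pool = by_label[label]
        if len(pool) < k:
            raise ValueError(f"label {label} has only {len(pool)} examples, need {k}")
        # random.sample draws without replacement, so no example repeats.
        support.extend(rng.sample(pool, k))
    return support


# Toy corpus: 4 classes with 10 placeholder texts each (not real AG_NEWS rows).
toy = [(f"text {label}-{i}", label) for label in range(4) for i in range(10)]
support = sample_k_shot(toy, k=5, seed=42)
print(len(support))  # 4 classes x 5 shots = 20
```

Because few-shot results depend heavily on which examples happen to be drawn, it is common to repeat the experiment with several seeds and compare the results.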

import numpy as np
import pandas as pd
from sklearn.metrics import classification_report
from simpletransformers.classification import ClassificationModel

# PyTorch packages (pip): torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113

K_SHOT = 5   # number of samples drawn from each class
SEED = 42    # fixed seed so the sampled examples are reproducible


def load_split(csv_file):
    # BERT tokenizes raw text itself, so no extra text preprocessing is done here
    raw = pd.read_csv(csv_file)
    df = pd.DataFrame()
    df['text'] = raw['Description']
    # Labels in the CSV run from 1 to N; map them to 0 to N-1
    df['labels'] = raw['Class Index'] - 1
    return df


def sample_per_class(df, k):
    # Draw k distinct rows per class (sampling without replacement)
    return df.groupby('labels').sample(n=k, random_state=SEED).reset_index(drop=True)


def train(train_file, test_file):
    small_df = sample_per_class(load_split(train_file), K_SHOT)

    # Configure the Simple Transformers classifier; swap in another BERT checkpoint here if desired
    model = ClassificationModel('bert', 'bert-base-cased', num_labels=4,
        args={'reprocess_input_data': True, 'overwrite_output_dir': True, 'num_train_epochs': 80, 'learning_rate': 5e-5,
              'train_batch_size': 20, 'eval_batch_size': 20},
        use_cuda=False
    )
    # Train the model on the 5-shot subset
    model.train_model(small_df)

    # Prepare the evaluation set: K_SHOT samples per class from the test split
    small_dt = sample_per_class(load_split(test_file), K_SHOT)

    # Evaluate the model; model_outputs holds one row of class scores per sample
    result, model_outputs, wrong_predictions = model.eval_model(small_dt)
    predicted = np.argmax(model_outputs, axis=1).tolist()
    true = small_dt['labels'].tolist()
    print(classification_report(true, predicted, target_names=['World', 'Sports', 'Business', 'Sci/Tech']))


if __name__ == "__main__":
    train_file = "data/train.csv"
    test_file = "data/test.csv"
    train(train_file, test_file)
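
The evaluation step in the script turns each row of raw class scores into a predicted label with argmax and then compares it to the true label. The sketch below shows that step in isolation with NumPy. The scores are made-up numbers for illustration (not output from the model above), and the helper names `predictions_from_logits` and `per_class_recall` are assumptions.

```python
import numpy as np

CLASS_NAMES = ["World", "Sports", "Business", "Sci/Tech"]


def predictions_from_logits(model_outputs):
    """Pick the highest-scoring class for each row of scores."""
    return np.argmax(np.asarray(model_outputs), axis=1)


def per_class_recall(true_labels, predicted_labels, num_classes):
    """Fraction of samples of each class that were predicted correctly."""
    true_labels = np.asarray(true_labels)
    predicted_labels = np.asarray(predicted_labels)
    recall = {}
    for class_id in range(num_classes):
        mask = true_labels == class_id
        if mask.any():  # skip classes that do not appear in the true labels
            recall[class_id] = float((predicted_labels[mask] == class_id).mean())
    return recall


# Made-up scores for four samples (rows) over the four classes (columns).
logits = [
    [2.0, 0.1, -1.0, 0.3],
    [0.2, 1.5, 0.0, -0.4],
    [0.1, 0.0, 0.2, 0.9],
    [1.1, 0.0, 0.3, 0.2],
]
true = [0, 1, 2, 0]

pred = predictions_from_logits(logits)
accuracy = float((pred == np.asarray(true)).mean())
recall = per_class_recall(true, pred, num_classes=len(CLASS_NAMES))

print("predicted:", [CLASS_NAMES[i] for i in pred])
print("accuracy:", accuracy)
for class_id, value in recall.items():
    print(f"recall for {CLASS_NAMES[class_id]}: {value}")
```

With only 5 evaluation samples per class, each misclassified sample moves that class's recall by 0.2, so these numbers should be read as rough indicators.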
