Few-Shot NLP: 5-Shot Classification on AG_NEWS

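  1. First, transfer learning takes a model pretrained on large-scale data (such as BERT or GPT) as the base model and fine-tunes it on a small amount of labeled data to learn the new task. This reuses the general-purpose features the base model learned from large corpora, which helps it understand and generalize to the specific domain of the new task.
  2. In a 5-shot classification task, we can freeze most of the base model's parameters and fine-tune only the last few layers to adapt to the new task. This lets the model converge faster and perform reasonably well with very little data. Meta-learning is another option: training across many small tasks to improve the model's ability to generalize.
  3. Cross-lingual transfer learning is another common strategy, particularly suited to NLP tasks in multilingual settings. A model trained on one language is applied to a task in another language, and this can also help in 5-shot classification.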
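
The meta-learning idea in item 2, learning across many small tasks, is usually organized around "episodes": each episode is a small N-way K-shot task with a support set to learn from and a query set to evaluate on. The following is a minimal sketch of episode sampling only, not a full meta-learning method. The function name `sample_episode`, its parameters, and the toy data are illustrative assumptions, not part of the original code.

```python
import random
from collections import defaultdict


def sample_episode(examples, n_way, k_shot, n_query, rng):
    """Draw one N-way K-shot episode from (text, label) pairs.

    Returns (support, query): two lists of (text, label) with no overlap.
    Hypothetical helper for illustration only.
    """
    by_label = defaultdict(list)
    for text, label in examples:
        by_label[label].append(text)

    # Only classes with enough examples can fill both sets.
    eligible = [lab for lab, texts in by_label.items()
                if len(texts) >= k_shot + n_query]
    if len(eligible) < n_way:
        raise ValueError("not enough classes with k_shot + n_query examples")

    support, query = [], []
    for label in rng.sample(sorted(eligible), n_way):
        # Sample without replacement so support and query never share a text.
        picked = rng.sample(by_label[label], k_shot + n_query)
        support += [(t, label) for t in picked[:k_shot]]
        query += [(t, label) for t in picked[k_shot:]]
    return support, query


# Toy data standing in for AG_NEWS rows: 4 classes, 10 texts each.
toy = [(f"article {c}-{i}", c) for c in range(4) for i in range(10)]
support, query = sample_episode(toy, n_way=4, k_shot=5, n_query=2,
                                rng=random.Random(0))
print(len(support), len(query))  # 20 8
```

With `n_way=4` and `k_shot=5` the support set has the same shape as the 5-shot AG_NEWS training subset used later in this post.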


2.2 Dataset

AG_NEWS contains 120,000 training samples (train.csv) and 7,600 test samples (test.csv). Each of the four classes has 30,000 training samples and 1,900 test samples.

2.1 BERT (Bidirectional Encoder Representations from Transformers)

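BERT is a pretrained language model proposed by Google in 2018. Unlike traditional language models, which predict the next word in one direction only (left to right or right to left), BERT is built on the Transformer and uses context from both directions during pretraining. Pretraining combines two objectives: Masked Language Model (MLM) and Next Sentence Prediction (NSP). For MLM, BERT randomly masks some tokens in the input sentence and predicts the masked tokens from the surrounding context. For NSP, BERT takes a pair of sentences and predicts whether the second one follows the first. After pretraining, BERT can be fine-tuned on downstream tasks such as text classification, named entity recognition, and natural language inference, which carries its learned language representations over to a wide range of NLP tasks. BERT's advantages include: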

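  • Pretraining uses bidirectional context, which helps capture the meaning and context of words in a sentence.
  • BERT can serve many downstream tasks with fine-tuning alone, without training a new model from scratch for each task.
  • BERT achieved strong results on many NLP tasks, surpassing earlier models on a number of them.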
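
To make the MLM objective described above concrete, here is a rough sketch of the token-corruption step. It assumes the scheme from the BERT paper: about 15% of positions are selected, and of those 80% become `[MASK]`, 10% become a random token, and 10% are left unchanged. It works on whitespace-split words rather than BERT's WordPiece tokens, and the function name `mask_tokens` is an illustrative assumption rather than any library's API.

```python
import random

MASK = "[MASK]"


def mask_tokens(tokens, vocab, rng, mask_prob=0.15):
    """Return (corrupted_tokens, targets) for masked language modeling.

    targets[i] holds the original token at selected positions, else None.
    Hypothetical helper for illustration only.
    """
    corrupted, targets = list(tokens), [None] * len(tokens)
    for i, tok in enumerate(tokens):
        if rng.random() >= mask_prob:
            continue  # position not selected for prediction
        targets[i] = tok
        roll = rng.random()
        if roll < 0.8:
            corrupted[i] = MASK               # 80%: replace with [MASK]
        elif roll < 0.9:
            corrupted[i] = rng.choice(vocab)  # 10%: replace with a random token
        # remaining 10%: keep the original token unchanged
    return corrupted, targets


tokens = "stocks rallied after the earnings report".split()
corrupted, targets = mask_tokens(tokens, vocab=tokens, rng=random.Random(0))
print(corrupted)
print(targets)
```

During pretraining the model only receives `corrupted` and is trained to recover the non-None entries of `targets`, which is why it has to use context on both sides of each position.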



import numpy as np
import pandas as pd
import sklearn.metrics  # import the submodule explicitly; classification_report lives here
from simpletransformers.classification import ClassificationModel

def train(train_file, test_file):
    # Reading the train and test files
    train_df = pd.read_csv(train_file)
    test_df = pd.read_csv(test_file)
    # Environment note (pip): torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url
    # BERT uses its own tokenizer, so no manual text preprocessing is done here
    df = pd.DataFrame()
    df['text'] = train_df['Description']
    df['label'] = train_df['Class Index']
    # Labels in the CSV run from 1 to N, so shift them to 0 to N-1
    df['label'] = df['label'].apply(lambda x : x -1)
    # Draw 5 distinct training samples from each class (the 5-shot setting)
    small_df = df.groupby('label').sample(n=5).reset_index(drop=True)
    # Configure the Simple Transformers model for text classification
    # Select the BERT checkpoint you want to fine-tune
    model = ClassificationModel('bert', 'bert-base-cased', num_labels=4,
        args={'reprocess_input_data':True, 'overwrite_output_dir':True, 'num_train_epochs':80, 'learning_rate':5e-5,
              'train_batch_size':20, 'eval_batch_size':20 })
    # Train the model on the small 5-shot dataset
    model.train_model(small_df)

    # Prepare the evaluation dataset
    dt = pd.DataFrame()
    dt['text'] = test_df['Description']
    dt['label'] = test_df['Class Index']
    dt['label'] = dt['label'].apply(lambda x : x -1)
    # Evaluate on 5 distinct test samples per class
    small_dt = dt.groupby('label').sample(n=5).reset_index(drop=True)
    # Evaluate the model
    result, model_outputs, wrong_predictions = model.eval_model(small_dt)
    predicted = []
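    for arr in model_outputs:
        # Each row holds one score per class; the highest score is the prediction
        predicted.append(np.argmax(arr))
    true = small_dt['label'].tolist()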
    print(sklearn.metrics.classification_report(true, predicted, target_names=['World','Sports','Business','Sci/Tech']))

if __name__=="__main__":
    train_file = "data/train.csv"
    test_file  = "data/test.csv"
    train(train_file, test_file)

AG's News Topic Classification Dataset

Version 3, Updated 09/09/2015

ORIGIN

AG is a collection of more than 1 million news articles. News articles have been gathered from more than 2000 news sources by ComeToMyHead in more than 1 year of activity. ComeToMyHead is an academic news search engine which has been running since July, 2004. The dataset is provided by the academic community for research purposes in data mining (clustering, classification, etc), information retrieval (ranking, search, etc), XML, data compression, data streaming, and any other non-commercial activity. For more information, please refer to the link .

The AG's news topic classification dataset is constructed by Xiang Zhang from the dataset above. It is used as a text classification benchmark in the following paper: Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances in Neural Information Processing Systems 28 (NIPS 2015).

DESCRIPTION

The AG's news topic classification dataset is constructed by choosing the 4 largest classes from the original corpus. Each class contains 30,000 training samples and 1,900 testing samples. The total number of training samples is 120,000 and testing 7,600.

The file classes.txt contains a list of classes corresponding to each label.

The files train.csv and test.csv contain all the training samples as comma-separated values. There are 3 columns in them, corresponding to class index (1 to 4), title and description. The title and description are escaped using double quotes ("), and any internal double quote is escaped by 2 double quotes (""). New lines are escaped by a backslash followed with an "n" character, that is "\n".
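
The escaping rules in the dataset description (fields wrapped in double quotes, internal quotes doubled, line breaks stored as a literal backslash followed by "n") can be handled with Python's standard csv module plus one string replacement. The sketch below is one way to do it under those stated rules. The function name `read_ag_news` and the sample row are made up for illustration, and the header check is an assumption that covers copies of the files that include a header row such as the "Class Index" / "Description" columns used in the training script.

```python
import csv
import io

# Class order as used for target_names in the training script.
LABELS = ["World", "Sports", "Business", "Sci/Tech"]


def read_ag_news(handle):
    """Yield (label_index, title, description) with labels shifted to 0..3.

    Hypothetical helper for illustration only.
    """
    for row in csv.reader(handle):
        if not row or not row[0].isdigit():
            continue  # skip blank lines and a header row, if present
        class_index, title, description = row
        # csv.reader already turns "" back into "; undo the \n escaping here.
        yield (int(class_index) - 1,
               title.replace("\\n", "\n"),
               description.replace("\\n", "\n"))


# A made-up row written with the same escaping rules as train.csv.
sample = io.StringIO(
    '"3","Stocks ""rally"" on jobs data","Line one\\nLine two"\n'
)
rows = list(read_ag_news(sample))
print(rows[0])
print(LABELS[rows[0][0]])  # Business
```

For the real files, the same function can be called on `open("data/train.csv", newline="", encoding="utf-8")`.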


