AG_NEWS Text Classification Task

1. Dataset Preparation and Preprocessing

1.1 The dataset used in this experiment is AG_NEWS

AG_NEWS is a news corpus covering four top-level categories: World, Sports, Business, and Sci/Tech. It contains 120,000 training samples (train.csv) and 7,600 test samples (test.csv), with 30,000 training samples and 1,900 test samples per class.

The dataset files look like this:

The class indices 1, 2, 3, and 4 correspond to World, Sports, Business, and Sci/Tech respectively.
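For reference, the raw iterators can be obtained directly from torchtext (a minimal sketch; torchtext.datasets.AG_NEWS downloads the CSVs on first use):

from torchtext.datasets import AG_NEWS

# Each split yields (label, text) pairs with labels in {1, 2, 3, 4}
train_iter, test_iter = AG_NEWS(split=('train', 'test'))
label, text = next(iter(train_iter))
print(label, text[:80])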

1.2 Building a vocabulary from the training set

This step is straightforward: count each token and assign it an index. We use two utilities provided by torchtext, get_tokenizer for tokenization and build_vocab_from_iterator for building the vocabulary; the concrete implementation lives in the bulvocab function (a sketch is given below).
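The original bulvocab implementation is not reproduced in this post; below is a minimal sketch of how it could look using the two torchtext utilities named above. Placing <PAD> at index 0 matches the padding convention in section 1.4; the <UNK> handling is an assumption.

from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

def bulvocab(traindata):
    """Build a token -> index vocabulary from the training set (sketch)."""
    tokenizer = get_tokenizer('basic_english')

    def yield_tokens(data_iter):
        for _label, text in data_iter:
            yield tokenizer(text)

    vocab = build_vocab_from_iterator(
        yield_tokens(traindata),
        specials=['<PAD>', '<UNK>'],  # <PAD> takes index 0, as required by section 1.4
    )
    vocab.set_default_index(vocab['<UNK>'])  # map out-of-vocabulary tokens to <UNK>
    return vocab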

1.3 Analyzing sentence lengths in the training set

The longest sentence has 207 tokens and the shortest 12; the mean length is 43.280075 and the mode is 40.

[Figure: frequency distribution of sentence lengths in the training set]
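These statistics can be reproduced in a few lines (a sketch, assuming traindata yields (label, text) pairs as in the loader code in section 1.4):

from collections import Counter
from torchtext.data.utils import get_tokenizer

tokenizer = get_tokenizer('basic_english')
lengths = [len(tokenizer(text)) for _label, text in traindata]
print('max:', max(lengths), 'min:', min(lengths),
      'mean:', sum(lengths) / len(lengths),
      'mode:', Counter(lengths).most_common(1)[0][0])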

1.4 Batch construction uses a different sentence-length policy per model: TextCNN uses the (rounded) mean length of 44, while the RNN-family models use the per-batch mean length.

Per-batch mean length: when DataLoader assembles a batch, the mean sentence length within that batch is computed; shorter sentences are padded with <PAD> (index 0 in the vocabulary) and longer ones are truncated to that length. The details are in the dateset2loader function below.

import torch
from torch.utils.data import DataLoader
from torchtext.data.utils import get_tokenizer


def dateset2loader(config, vocab, traindata, testdata):
    tokenizer = get_tokenizer('basic_english')  # basic English tokenizer that splits a sentence into tokens (similar to jieba for Chinese)
    # Step 3: build the DataLoaders
    ##########################################################################

    print("Step3: Dataset -> DataLoader")
    ##########################################################################
    # text_pipeline converts a text string into a list of integers,
    # one vocabulary index per token
    text_pipeline = lambda x: vocab(tokenizer(x))

    # label_pipeline converts a label to a 0-based integer
    label_pipeline = lambda x: int(x) - 1

    # Convert a batch of raw samples into tensors
    def collate_batch(batch):
        """
        e.g. (3, "Wall ...") -> (2, [467, ...])
        :param batch: list of (label, text) pairs
        :return: (labels, padded token-id matrix)
        """
        label_list, text_list = [], []
        for (_label, _text) in batch:
            label_list.append(label_pipeline(_label))
            processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
            text_list.append(processed_text)
        # Choose the common sequence length for this batch
        if config.seq_mode == "min":
            seq_len = min(len(item) for item in text_list)
        elif config.seq_mode == "max":
            seq_len = max(len(item) for item in text_list)
        elif config.seq_mode == "avg":
            seq_len = sum(len(item) for item in text_list) / len(text_list)
        elif isinstance(config.seq_mode, int):
            seq_len = config.seq_mode
        else:
            seq_len = min(len(item) for item in text_list)
        seq_len = int(seq_len)
        # Pad/truncate every sequence in the batch to seq_len
        batch_seq = torch.stack(tensor_padding(text_list, seq_len))
        label_list = torch.tensor(label_list, dtype=torch.int64)
        return label_list, batch_seq

    train_dataloader = DataLoader(traindata, batch_size=config.batchsize, shuffle=True, collate_fn=collate_batch)
    test_dataloader = DataLoader(testdata, batch_size=config.batchsize, shuffle=True, collate_fn=collate_batch)
    return train_dataloader, test_dataloader
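The tensor_padding helper used above is not shown in the post; a minimal sketch consistent with the description (pad with <PAD> = index 0; truncating longer sequences is an assumption, needed for torch.stack to work):

def tensor_padding(text_list, seq_len):
    """Truncate or zero-pad each 1-D token tensor to exactly seq_len entries.

    Sketch only: the original helper is not listed; index 0 is <PAD>.
    """
    padded = []
    for t in text_list:
        if len(t) >= seq_len:
            padded.append(t[:seq_len])  # truncate long sequences
        else:
            pad = torch.zeros(seq_len - len(t), dtype=torch.int64)  # <PAD> = 0
            padded.append(torch.cat([t, pad]))
    return padded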

2. Model Construction

The models use torch.nn.RNN, torch.nn.LSTM, torch.nn.GRU, and torch.nn.Conv2d as provided by PyTorch.

2.1 TextCnn

The structure follows the original TextCNN; only the final classification layer is changed from 2 classes to 4. A sketch of such a module is given below.
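The post does not print the TextCNN module itself; below is a minimal sketch consistent with the hyperparameters in section 4.1 (embedding size 128, 4 classes). The filter sizes (3, 4, 5) and 100 filters per size are assumptions carried over from the original TextCNN paper:

import torch
import torch.nn as nn
import torch.nn.functional as F

class textCnn(nn.Module):
    # Sketch: filter sizes and counts follow Kim (2014); not shown in the post.
    def __init__(self, len_vocab, embedding_size, sen_len, num_class,
                 filter_sizes=(3, 4, 5), num_filters=100):
        super().__init__()
        self.embedding = nn.Embedding(len_vocab, embedding_size)
        # One Conv2d per filter size, sliding over the (sen_len, embedding_size) "image"
        self.convs = nn.ModuleList(
            nn.Conv2d(1, num_filters, (fs, embedding_size)) for fs in filter_sizes
        )
        self.fc = nn.Linear(num_filters * len(filter_sizes), num_class)

    def forward(self, x):                    # x: (batch, sen_len)
        x = self.embedding(x).unsqueeze(1)   # (batch, 1, sen_len, emb)
        feats = [F.relu(conv(x)).squeeze(3) for conv in self.convs]      # (batch, nf, L')
        pooled = [F.max_pool1d(f, f.size(2)).squeeze(2) for f in feats]  # (batch, nf)
        return self.fc(torch.cat(pooled, dim=1))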

2.2 RNN-family models

Vanilla RNN:

RNNnet(
  (embedding): Embedding(95811, 128)
  (rnn): RNN(128, 256, batch_first=True)
  (fc): Linear(in_features=256, out_features=4, bias=True)
)

LSTM:

RNNnet(
  (embedding): Embedding(95811, 128)
  (rnn): LSTM(128, 256, batch_first=True)
  (fc): Linear(in_features=256, out_features=4, bias=True)
)

GRU:

RNNnet(
  (embedding): Embedding(95811, 128)
  (rnn): GRU(128, 256, batch_first=True)
  (fc): Linear(in_features=256, out_features=4, bias=True)
)
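The RNNnet module itself is not listed in the post; a minimal sketch matching the printouts above (embedding 95811 x 128, hidden size 256, batch_first, a 4-way linear head; classifying from the last time step is an assumption):

import torch.nn as nn

class RNNnet(nn.Module):
    # Sketch matching the printed structure; mode selects nn.RNN / nn.LSTM / nn.GRU.
    def __init__(self, len_vocab, embedding_size, hidden_size, num_class,
                 num_layers=1, mode='rnn'):
        super().__init__()
        self.embedding = nn.Embedding(len_vocab, embedding_size)
        rnn_cls = {'rnn': nn.RNN, 'lstm': nn.LSTM, 'gru': nn.GRU}[mode]
        self.rnn = rnn_cls(embedding_size, hidden_size,
                           num_layers=num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_class)

    def forward(self, x):                      # x: (batch, seq_len)
        out, _ = self.rnn(self.embedding(x))   # out: (batch, seq_len, hidden)
        return self.fc(out[:, -1, :])          # last time step (assumption)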

3. Building the Training Loop

3.1 First select the device: use CUDA if available, otherwise fall back to the CPU

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

3.2 Then instantiate the model

if config.mode == 'cnn':
    model = textCnn(
        len_vocab=len_vocab,
        embedding_size=config.embedding_size,
        sen_len=config.seq_mode,
        num_class=len(classes),
    )
else:
    model = RNNnet(
        len_vocab=len_vocab,
        embedding_size=config.embedding_size,
        hidden_size=config.hidden_size,
        num_class=len(classes),
        num_layers=config.num_layers,
        mode=config.mode
    )

3.3 The optimizer is Adam and the loss function is cross-entropy

optimizer = torch.optim.Adam(model.parameters(), lr=config.l_r)
loss_fn = nn.CrossEntropyLoss()

3.4 Finally, iterate over the data and train. After every epoch the model is evaluated on the test set, and the model with the best test-set accuracy is saved. A sketch of this loop follows.
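The loop itself is not listed in the post; a minimal sketch of the described procedure (the evaluate helper and the checkpoint filename are assumptions):

def evaluate(model, loader, device):
    """Return classification accuracy on a DataLoader (hypothetical helper)."""
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for labels, texts in loader:
            preds = model(texts.to(device)).argmax(dim=1)
            correct += (preds == labels.to(device)).sum().item()
            total += labels.size(0)
    return correct / total


model = model.to(device)
best_acc = 0.0
for epoch in range(config.epochs):
    model.train()
    for labels, texts in train_dataloader:
        labels, texts = labels.to(device), texts.to(device)
        optimizer.zero_grad()
        loss = loss_fn(model(texts), labels)
        loss.backward()
        optimizer.step()
    # Validate on the test set after every epoch; keep the best checkpoint
    acc = evaluate(model, test_dataloader, device)
    if acc > best_acc:
        best_acc = acc
        torch.save(model.state_dict(), 'best_model.pt')  # filename is an assumption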

4. Model Evaluation

4.1 All hyperparameters (all models were trained on the CPU):

embedding_size: 128
hidden_size (RNN family): 256
num_layers (RNN family): 1
l_r: 1e-3
epochs: 50
batchsize: 1024

4.2 Sentence-length settings

TextCNN: 44
RNN family: per-batch mean length

4.3 Final accuracy of each model on the training and test sets

TextCNN: Train_set 95.36%, Test_set 86.5%
Vanilla RNN: Train_set 96.62%, Test_set 89.75%
LSTM: Train_set 99.83%, Test_set 90.67%
GRU: Train_set 99.85%, Test_set 90.46%

4.4 Loss per epoch:

[Figures: per-epoch loss curves for TextCNN, Vanilla RNN, LSTM, and GRU]

4.5 Training-set accuracy per epoch:

[Figures: per-epoch training-set accuracy curves for TextCNN, Vanilla RNN, LSTM, and GRU]

4.6 Test-set accuracy per epoch:

[Figures: per-epoch test-set accuracy curves for TextCNN, Vanilla RNN, LSTM, and GRU]

4.7 Testing on an unseen sample (a sports news article found online, translated into English with Google Translate; the Chinese original and the English input are shown below):

北京时间4月17日,NBA附加赛,国王118-94大胜勇士,勇士被淘汰出局无缘季后赛,国王将与鹈鹕争夺最后一个季后赛名额。国王:基根-穆雷32分9篮板、福克斯24分4篮板6助攻、巴恩斯17分4篮板3助攻、萨博尼斯16分12篮板7助攻、埃利斯15分4篮板5助攻3抢断3盖帽。勇士:库里22分4篮板、库明加16分7篮板、穆迪16分3篮板、维金斯12分3篮板、追梦12分3篮板6助攻、保罗3分2助攻、克雷-汤普森10投0中一分未得。首节比赛,国王三分手感火热,领先勇士9分结束第一节。第二节,国王打出高潮将比分拉开,勇士节末不断追分。半场结束时,勇士落后国王4分。第三节,勇士一度将分差追到仅剩1分,国王延续三分手感再度拉开分差,国王领先15分进入最后一节。末节,国王将分差继续拉大,勇士崩盘,分差超过20分。最终,国王118-94大胜勇士。

On April 17, Beijing time, in the NBA play-offs, the Kings defeated the Warriors 118-94, the Warriors were eliminated from the playoffs, and the Kings will compete with the Pelicans for the last playoff spot. Kings: Keegan Murray 32 points and 9 rebounds, Fox 24 points, 4 rebounds and 6 assists, Barnes 17 points, 4 rebounds and 3 assists, Sabonis 16 points, 12 rebounds and 7 assists, Ellis 15 points, 4 rebounds, 5 assists, 3 steals and 3 blocks. Warriors: Curry 22 points and 4 rebounds, Kuminga 16 points and 7 rebounds, Moody 16 points and 3 rebounds, Wiggins 12 points and 3 rebounds, Dream 12 points, 3 rebounds and 6 assists, Paul 3 points and 2 assists, Klay Thompson did not score a point on 0-of-10 shooting. In the first quarter of the game, the Kings were hot with a three-point hand, leading the Warriors by 9 points to end the first quarter. In the second quarter, the Kings played a climax to pull the score away, and the Warriors continued to chase points at the end of the quarter. At halftime, the Warriors trailed the Kings by four points. In the third quarter, the Warriors once chased the difference to only one point, and the King continued to open the gap again with a three-point hand, and the King led by 15 points into the final quarter. In the final quarter, the Kings continued to widen the difference, and the Warriors collapsed, with a margin of more than 20 points. In the end, the Kings defeated the Warriors 118-94.

Predictions:

TextCNN: Sports
Vanilla RNN: Sports
LSTM: Sports
GRU: Sports
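For reference, a single-sample prediction can be run like this (a sketch; the predict helper is hypothetical and reuses the tensor_padding sketch from section 1.4):

def predict(model, text, vocab, tokenizer, seq_len, device):
    """Classify one raw news string (hypothetical helper)."""
    classes = ['World', 'Sports', 'Business', 'Sci/Tech']
    ids = torch.tensor(vocab(tokenizer(text)), dtype=torch.int64)
    ids = tensor_padding([ids], seq_len)[0].unsqueeze(0).to(device)  # shape (1, seq_len)
    model.eval()
    with torch.no_grad():
        return classes[model(ids).argmax(dim=1).item()]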

5. Model Comparison

5.1 Loss comparison:

[Figure: loss curves of all four models plotted together]

5.2 Training-set accuracy comparison:

[Figure: training-set accuracy curves of all four models plotted together]

5.3 Test-set accuracy comparison:

[Figure: test-set accuracy curves of all four models plotted together]

Full code

Available at this link: 诗意画/AG_NEWS text classification task

