Text Classification on the AG_NEWS Dataset in Practice (Part 1)

1. Loading the Dataset

We'll use the AG_NEWS dataset to build a simple text classification model. First, here are the libraries we need.

import torch
import torch.nn as nn
import torchtext
from torchtext.datasets import AG_NEWS
import os
from collections import Counter, OrderedDict

Then we load the training and test sets.

os.makedirs('./data', exist_ok=True)
train_dataset, test_dataset = AG_NEWS(root='./data', split=('train', 'test'))
classes = ['World', 'Sports', 'Business', 'Sci/Tech']

In case the download above fails, here are direct links to the training and test CSV files:

https://raw.githubusercontent.com/mhjabreel/CharCnn_Keras/master/data/ag_news_csv/train.csv
https://raw.githubusercontent.com/mhjabreel/CharCnn_Keras/master/data/ag_news_csv/test.csv
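If you download the CSVs by hand, a small loader like the sketch below reproduces the (label, text) tuples that the torchtext dataset yields. This is my own fallback sketch, not part of torchtext; the paths and the helper name load_ag_csv are assumptions.

import csv

# Minimal fallback loader, assuming the two CSVs above were saved under ./data.
# Each row is: class index (1-4), title, description; there is no header row.
def load_ag_csv(path):
    samples = []
    with open(path, encoding='utf-8') as f:
        for label, title, description in csv.reader(f):
            # concatenate title and description, matching the format torchtext yields
            samples.append((int(label), title + ' ' + description))
    return samples

# train_dataset = load_ag_csv('./data/train.csv')
# test_dataset = load_ag_csv('./data/test.csv')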

Let's print the first five training samples:

for i, x in zip(range(5), train_dataset):
    print(f"**{classes[x[0]-1]}** -> {x[1]}\n")  # labels are 1-based, hence x[0]-1

The output is:

**Business** -> Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\band of ultra-cynics, are seeing green again.

**Business** -> Carlyle Looks Toward Commercial Aerospace (Reuters) Reuters - Private investment firm Carlyle Group,\which has a reputation for making well-timed and occasionally\controversial plays in the defense industry, has quietly placed\its bets on another part of the market.

**Business** -> Oil and Economy Cloud Stocks' Outlook (Reuters) Reuters - Soaring crude prices plus worries\about the economy and the outlook for earnings are expected to\hang over the stock market next week during the depth of the\summer doldrums.

**Business** -> Iraq Halts Oil Exports from Main Southern Pipeline (Reuters) Reuters - Authorities have halted oil export\flows from the main pipeline in southern Iraq after\intelligence showed a rebel militia could strike\infrastructure, an oil official said on Saturday.

**Business** -> Oil prices soar to all-time record, posing new menace to US economy (AFP) AFP - Tearaway world oil prices, toppling records and straining wallets, present a new economic menace barely three months before the US presidential elections.

We convert the training and test sets into lists, so they can be indexed and traversed more than once:

train_dataset = list(train_dataset)
test_dataset = list(test_dataset)
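As a quick sanity check, the sizes should match the dataset README: 120,000 training samples and 7,600 test samples.

print(len(train_dataset), len(test_dataset))  # expected: 120000 7600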

2. Building the Vocabulary and the DataLoader

First, choose a tokenizer:

tokenizer = torchtext.data.utils.get_tokenizer('basic_english')
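A quick look at what basic_english does: it lowercases the text and splits punctuation such as periods into separate tokens. For example:

print(tokenizer("Wall St. Bears Claw Back Into the Black"))
# ['wall', 'st', '.', 'bears', 'claw', 'back', 'into', 'the', 'black']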

Next, build the vocabulary:

counter = Counter()
for (label, line) in train_dataset:
    counter.update(tokenizer(line))
# sort tokens by frequency, descending
order_dict = OrderedDict(sorted(counter.items(), key=lambda x: x[1], reverse=True))
# build the word-to-index vocabulary
vocab = torchtext.vocab.vocab(order_dict, min_freq=1)
vocab_size = len(vocab)
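A small sanity check of the resulting vocabulary; the exact indices depend on corpus word frequencies, so treat any printed numbers as illustrative only:

print(vocab_size)  # number of distinct tokens in the training split
# frequent words map to small indices, since the vocabulary is frequency-sorted
print(vocab.lookup_indices(tokenizer("wall street stocks")))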

When working with text, sentence lengths vary, while the tensors in a minibatch must share one shape; we therefore pad each sequence at the end up to the length of the longest sequence in the minibatch. Note that the code below pads with index 0, which (since we built the vocabulary without special tokens) is also the index of the most frequent word; inserting a dedicated <pad> token would be cleaner, but for this simple averaging model the collision matters little in practice.

# sentence lengths vary, so pad every sequence in a minibatch to the longest one
def padify(b):
    # b is a list of batch_size tuples:
    #   - first element of each tuple = label
    #   - second element = feature (the text string)
    # vectorize each text into a list of vocabulary indices
    v = [vocab.lookup_indices(tokenizer(x[1])) for x in b]
    # first, compute the max sequence length in this minibatch
    l = max(map(len, v))
    return ( # tuple of two tensors - labels and padded features
        torch.LongTensor([t[0]-1 for t in b]),  # shift labels from 1-4 to 0-3
        torch.stack([torch.nn.functional.pad(torch.tensor(t), (0, l-len(t)), mode='constant', value=0) for t in v]))
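A quick demonstration of padify on the first four training samples; the exact padded length depends on those samples, so the shapes are illustrative:

labels, features = padify(train_dataset[:4])
print(labels)          # tensor of four zero-based class labels
print(features.shape)  # torch.Size([4, L]), L = longest sequence in this minibatch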

Here is the DataLoader for the training set:

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, collate_fn=padify, shuffle=True)

3. Building the Model

Since this is the first post in this text-classification series, we use an extremely simple model: after a word-embedding layer, we average the embedding vectors of all the words in a sentence and pass the result through a single fully connected layer. Note that no softmax appears in the model itself; as we'll see below, nn.CrossEntropyLoss applies it internally.

class EmbedClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_class):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.fc = nn.Linear(embed_dim, num_class)
    def forward(self, x):
        x = self.embedding(x)
        x = torch.mean(x, dim=1)  # average the word vectors over the sentence
        return self.fc(x)
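Before training, a forward-pass shape check on a throwaway instance confirms the wiring; sanity_net below is created only for this check and discarded:

sanity_net = EmbedClassifier(vocab_size, 32, len(classes))
labels, features = padify(train_dataset[:4])
print(sanity_net(features).shape)  # torch.Size([4, 4]) - one logit per class for each sample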

Next we build the training function. The loss is cross-entropy and the optimizer is Adam.

def train_epoch(net, dataloader, lr, optimizer=None, loss_fn=None, epoch_size=None, report_freq=200):
    optimizer = optimizer or torch.optim.Adam(net.parameters(), lr=lr)
    loss_fn = loss_fn or nn.CrossEntropyLoss()
    net.train()  # switch to training mode
    total_loss, acc, count, i = 0, 0, 0, 0
    for labels, features in dataloader:
        optimizer.zero_grad()
        out = net(features)
        loss = loss_fn(out, labels)  # CrossEntropyLoss applies softmax internally
        loss.backward()
        optimizer.step()
        total_loss += loss.item()  # .item() avoids keeping the computation graph alive
        predicted = torch.argmax(out, 1)  # argmax over the class dimension
        acc += (predicted == labels).sum().item()
        count += len(labels)
        i += 1
        if i % report_freq == 0:
            print(f"{count}: acc={acc/count}")
        if epoch_size and count > epoch_size:
            break
    print(f'loss is {total_loss/count}')
    print(f'accuracy is {acc/count*100}%')

Now we train an instance of our model:

net = EmbedClassifier(vocab_size, 32, len(classes))
train_epoch(net, train_loader, lr=1, epoch_size=25000)

The output is:

3200: acc=0.641875
6400: acc=0.67984375
9600: acc=0.7048958333333334
12800: acc=0.71765625
16000: acc=0.728375
19200: acc=0.7410416666666667
22400: acc=0.7509375
loss is 0.9129715990882917
accuracy is 75.58381317978247%

After one epoch (capped at 25,000 samples by the epoch_size argument), the model reaches roughly 75% accuracy on the training set, which is quite decent for such a simple model.
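The post stops at training accuracy, but since we already loaded test_dataset, here is a sketch of test-set evaluation; the evaluate helper and test_loader are my additions, not part of the original code:

def evaluate(net, dataloader):
    net.eval()  # switch to evaluation mode
    acc, count = 0, 0
    with torch.no_grad():  # no gradients needed for evaluation
        for labels, features in dataloader:
            out = net(features)
            acc += (torch.argmax(out, 1) == labels).sum().item()
            count += len(labels)
    return acc / count

test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16, collate_fn=padify)
print(f'test accuracy: {evaluate(net, test_loader)*100:.2f}%')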

For reference, here is the dataset's official README:

AG's News Topic Classification Dataset, Version 3, updated 09/09/2015.

ORIGIN: AG is a collection of more than 1 million news articles. News articles have been gathered from more than 2,000 news sources by ComeToMyHead in more than one year of activity. ComeToMyHead is an academic news search engine which has been running since July 2004. The dataset is provided by the academic community for research purposes in data mining (clustering, classification, etc.), information retrieval (ranking, search, etc.), XML, data compression, data streaming, and any other non-commercial activity. For more information, please refer to http://www.di.unipi.it/~gulli/AG_corpus_of_news_articles.html . The AG's news topic classification dataset is constructed by Xiang Zhang (xiang.zhang@nyu.edu) from the dataset above. It is used as a text classification benchmark in the following paper: Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances in Neural Information Processing Systems 28 (NIPS 2015).

DESCRIPTION: The AG's news topic classification dataset is constructed by choosing the 4 largest classes from the original corpus. Each class contains 30,000 training samples and 1,900 testing samples. The total number of training samples is 120,000 and testing 7,600. The file classes.txt contains a list of classes corresponding to each label. The files train.csv and test.csv contain all the training samples as comma-separated values. There are 3 columns in them, corresponding to class index (1 to 4), title and description. The title and description are escaped using double quotes ("), and any internal double quote is escaped by 2 double quotes (""). New lines are escaped by a backslash followed by an "n" character, that is "\n".