Text Classification on the AG_NEWS Dataset (Part 1)


1. Loading the Dataset

In this post we build a simple text classification model on the AG_NEWS dataset, which contains news articles in four classes (World, Sports, Business, Sci/Tech), with 120,000 training samples and 7,600 test samples. First, these are the libraries we will use.

import torch
import torch.nn as nn
import torchtext
from torchtext.datasets import AG_NEWS
import os
from collections import Counter, OrderedDict

Then we load the training and test sets.

os.makedirs('./data', exist_ok=True)
train_dataset, test_dataset = AG_NEWS(root='./data', split=('train', 'test'))
classes = ['World', 'Sports', 'Business', 'Sci/Tech']

In case the method above fails, here are direct download links for the training and test sets.

https://raw.githubusercontent.com/mhjabreel/CharCnn_Keras/master/data/ag_news_csv/train.csv
https://raw.githubusercontent.com/mhjabreel/CharCnn_Keras/master/data/ag_news_csv/test.csv
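
If you have to fall back on the manual download, a minimal sketch like the one below could stand in for the torchtext loader. Here load_ag_news_csv is a hypothetical helper; it assumes the standard AG_NEWS CSV layout (three columns: class index 1-4, title, description) and mirrors the (label, text) tuples that AG_NEWS yields.

import csv

def load_ag_news_csv(path):
    # each row holds three comma-separated fields: class index (1-4), title, description;
    # the csv module already unescapes the doubled quotes used in the files
    samples = []
    with open(path, encoding='utf-8') as f:
        for label, title, description in csv.reader(f):
            samples.append((int(label), title + ' ' + description))
    return samples

# train_dataset = load_ag_news_csv('./data/train.csv')
# test_dataset = load_ag_news_csv('./data/test.csv')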

Let's print the first five samples of the training set:

for i, x in zip(range(5), train_dataset):
    print(f"**{classes[x[0]-1]}** -> {x[1]}\n")

The output is:

**Business** -> Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\band of ultra-cynics, are seeing green again.

**Business** -> Carlyle Looks Toward Commercial Aerospace (Reuters) Reuters - Private investment firm Carlyle Group,\which has a reputation for making well-timed and occasionally\controversial plays in the defense industry, has quietly placed\its bets on another part of the market.

**Business** -> Oil and Economy Cloud Stocks' Outlook (Reuters) Reuters - Soaring crude prices plus worries\about the economy and the outlook for earnings are expected to\hang over the stock market next week during the depth of the\summer doldrums.

**Business** -> Iraq Halts Oil Exports from Main Southern Pipeline (Reuters) Reuters - Authorities have halted oil export\flows from the main pipeline in southern Iraq after\intelligence showed a rebel militia could strike\infrastructure, an oil official said on Saturday.

**Business** -> Oil prices soar to all-time record, posing new menace to US economy (AFP) AFP - Tearaway world oil prices, toppling records and straining wallets, present a new economic menace barely three months before the US presidential elections.

Convert the training and test sets into plain Python lists, since the torchtext datasets are iterators and we want to traverse them more than once:

train_dataset = list(train_dataset)
test_dataset = list(test_dataset)

2. Building the Vocabulary and the DataLoader

Choose a tokenizer:

tokenizer = torchtext.data.utils.get_tokenizer('basic_english')
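
As a quick check, basic_english lowercases the text and splits punctuation off into separate tokens; for example (the output shown is roughly what the tokenizer produces):

print(tokenizer('Wall St. Bears Claw Back Into the Black'))
# ['wall', 'st', '.', 'bears', 'claw', 'back', 'into', 'the', 'black']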

Build the vocabulary:

counter = Counter()
for (label, line) in train_dataset:
    counter.update(tokenizer(line))
# sort words by frequency in descending order
order_dict = OrderedDict(sorted(counter.items(), key=lambda x: x[1], reverse=True))
# build the vocabulary, mapping each word to an integer index
vocab = torchtext.vocab.vocab(order_dict, min_freq=1)
vocab_size = len(vocab)
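
The vocabulary size and a couple of lookups serve as a quick sanity check. The index values below are illustrative, since they depend on the corpus statistics:

print(f"Vocab size = {vocab_size}")
print(vocab.lookup_indices(tokenizer('wall street')))  # e.g. [274, 247] -- illustrative indices
# note: this vocabulary has no fallback for unknown words; on recent torchtext
# versions, vocab.set_default_index(0) would prevent errors on out-of-vocabulary input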

When processing text data, sentence lengths vary, so we append padding at the end of each sentence to bring every sequence in a minibatch up to the same length. Note that we pad with index 0, which the vocabulary also assigns to its most frequent word; a more careful implementation would reserve a dedicated padding token.

# sentence lengths vary, so pad every sequence up to the longest one in the minibatch
def padify(b):
    # b is a list of (label, text) tuples of length batch_size
    # convert each text into a sequence of vocabulary indices
    v = [vocab.lookup_indices(tokenizer(x[1])) for x in b]
    # compute the max sequence length in this minibatch
    l = max(map(len, v))
    return (  # tuple of two tensors - labels and features
        torch.LongTensor([t[0]-1 for t in b]),  # shift labels from 1-4 to 0-3
        torch.stack([torch.nn.functional.pad(torch.tensor(t), (0, l-len(t)), mode='constant', value=0) for t in v]))
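
As an example of how padify behaves, here is a toy minibatch of two samples (a sketch; the exact index values depend on the vocabulary):

labels, features = padify([(3, 'stocks rally'), (2, 'team wins the final game')])
print(labels)          # tensor([2, 1]) -- labels shifted from 1-4 to 0-3
print(features.shape)  # torch.Size([2, 5]) -- the shorter sequence is padded with zeros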

Below is the DataLoader for the training set:

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, collate_fn=padify, shuffle=True)
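
Pulling a single minibatch confirms the shapes (the sequence length varies from batch to batch, so the second dimension below is only an example):

labels, features = next(iter(train_loader))
print(labels.shape, features.shape)  # e.g. torch.Size([16]) torch.Size([16, 43])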

3. Building the Model

Since this is the first post in this text classification series, we use an extremely simple model: a word embedding layer followed by a single fully connected layer. The word vectors of a sentence are averaged into one vector, which is then passed through the fully connected layer. We do not apply softmax explicitly here, because the cross-entropy loss used later computes it internally.

class EmbedClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_class):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.fc = nn.Linear(embed_dim, num_class)
    def forward(self, x):
        x = self.embedding(x)
        x = torch.mean(x, dim=1)  # average the word vectors across the sentence
        return self.fc(x)
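
A quick forward pass with a dummy batch verifies the shapes (the numbers here are arbitrary):

dummy = torch.randint(0, 100, (4, 10))         # 4 "sentences", 10 token indices each
out = EmbedClassifier(100, 32, len(classes))(dummy)
print(out.shape)  # torch.Size([4, 4]) -- one logit per class for each sample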

Next we build the training function, using cross-entropy as the loss function and Adam as the optimizer.

def train_epoch(net, dataloader, lr, optimizer=None, loss_fn=None, epoch_size=None, report_freq=200):
    optimizer = optimizer or torch.optim.Adam(net.parameters(), lr=lr)
    loss_fn = loss_fn or nn.CrossEntropyLoss()
    net.train()  # switch to training mode
    total_loss, acc, count, i = 0, 0, 0, 0
    for labels, features in dataloader:
        optimizer.zero_grad()
        out = net(features)
        loss = loss_fn(out, labels)  # CrossEntropyLoss applies softmax internally, so out holds raw logits
        loss.backward()
        optimizer.step()
        total_loss += loss.item()  # .item() detaches the value so the graph can be freed
        predicted = torch.argmax(out, 1)  # argmax over dim 1, i.e. the predicted class of each sample
        acc += (predicted == labels).sum()
        count += len(labels)
        i += 1
        if i % report_freq == 0:
            print(f"{count}: acc={acc.item()/count}")
        if epoch_size and count > epoch_size:
            break
    print(f'loss is {total_loss/count}')
    print(f'accuracy is {acc.item()/count*100}%')

Now we train an instance of the model:

net = EmbedClassifier(vocab_size, 32, len(classes))
train_epoch(net, train_loader, lr=1, epoch_size=25000)

The output is:

3200: acc=0.641875
6400: acc=0.67984375
9600: acc=0.7048958333333334
12800: acc=0.71765625
16000: acc=0.728375
19200: acc=0.7410416666666667
22400: acc=0.7509375
loss is 0.9129715990882917
accuracy is 75.58381317978247%

As we can see, after a single epoch of training the model reaches roughly 75% accuracy on the training set, which is quite respectable for such a simple model.
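
Training accuracy alone doesn't tell us how well the model generalizes. A minimal evaluation sketch on the held-out test set might look like the following; evaluate and test_loader are additions not used above, and set_default_index is only available on recent torchtext versions.

# the vocabulary was built from the training set only, so the test set contains
# out-of-vocabulary words; map them to index 0 to avoid lookup errors
vocab.set_default_index(0)

def evaluate(net, dataloader):
    net.eval()  # switch to evaluation mode
    acc, count = 0, 0
    with torch.no_grad():
        for labels, features in dataloader:
            predicted = torch.argmax(net(features), dim=1)
            acc += (predicted == labels).sum().item()
            count += len(labels)
    return acc / count

test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16, collate_fn=padify)
print(f'test accuracy: {evaluate(net, test_loader)*100:.2f}%')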
