Encoding, Classification, and Prediction on the AG_NEWS Dataset

NLP beginner notes: there are plenty of download mirrors for this dataset online, so find one and save it locally. This post combines material from several other bloggers.
path = r'<your local path>\datasets\AG_NEWS.data\datasets\AG_NEWS'

import torch
from torchtext.datasets import AG_NEWS

# train_iter = AG_NEWS(root=path, split='train')
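The commented-out line is torchtext's built-in loader, which yields (label, text) pairs directly; it gave me trouble even with a local copy of the data (see the note at the end), so the CSVs are loaded manually below.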
#####
import pandas as pd
def load_data(csv_file):
    # header=None: the AG_NEWS CSVs have no header row, so keep the first line as data
    df = pd.read_csv(csv_file, header=None)
    data_tmp = []

    # iterate row by row: _ is the row index, row holds the column values
    for _, row in df.iterrows():
        label = row[0]
        context = row[1] + ' ' + row[2]  # join title and description (with a space so words don't run together)
        data_tmp.append((label, context))
    return data_tmp
train_iter = train_dataset = load_data(r"D:\study\dataset\AG News\train.csv")
test_iter = test_dataset = load_data(r"D:\study\dataset\AG News\test.csv")
from torchtext.data.utils import get_tokenizer
from collections import Counter
from torchtext.vocab import vocab

tokenizer = get_tokenizer('basic_english')
counter = Counter()
for (label, line) in train_iter:
    counter.update(tokenizer(line))
# Added by me: without also counting the test-set tokens, predict() kept raising errors on out-of-vocabulary words
for (label, line) in test_iter:
    counter.update(tokenizer(line))
vocab = vocab(counter, min_freq=1)
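Folding the test set into the counter works, but a cleaner fix for the out-of-vocabulary errors (a sketch, assuming torchtext >= 0.12; note that the line above also shadows the imported vocab() factory) is to reserve an <unk> token and make it the default index:

# from torchtext.vocab import vocab as build_vocab
# vocab = build_vocab(counter, min_freq=1, specials=['<unk>'])
# vocab.set_default_index(vocab['<unk>'])  # unseen tokens map to <unk> instead of raising an error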


text_pipeline = lambda x: [vocab[token] for token in tokenizer(x)]

label_pipeline = lambda x: int(x) - 1
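For example (the exact indices depend on vocabulary insertion order, so the numbers here are illustrative only):

# text_pipeline('here is an example')  ->  [475, 21, 30, 5297]   one id per token
# label_pipeline('3')                  ->  2                     CSV labels are 1-4; CrossEntropyLoss expects 0-3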


from torch.utils.data import DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def collate_batch(batch):
    label_list, text_list, offsets = [], [], [0]
    for (_label, _text) in batch:
        label_list.append(label_pipeline(_label))
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(processed_text)
        offsets.append(processed_text.size(0))

    label_list = torch.tensor(label_list, dtype=torch.int64)
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
    text_list = torch.cat(text_list)
    return label_list.to(device), text_list.to(device), offsets.to(device)
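EmbeddingBag consumes one flat tensor of token ids plus the start offset of each sample, which is why the per-sample lengths are cumulatively summed and shifted right by one (the offsets[:-1] above). A worked example with two texts of lengths 3 and 2:

# text_list -> tensor([t0, t1, t2, u0, u1])   five concatenated token ids
# offsets   -> tensor([0, 3])                 sample i starts at index offsets[i]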


from torch import nn
class TextClassificationModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_class):
        super(TextClassificationModel, self).__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)
        output = self.fc(embedded)
        return output


num_class = 4
vocab_size = len(vocab)
emsize = 64
model = TextClassificationModel(vocab_size, emsize, num_class).to(device)
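A quick smoke test (a sketch with arbitrary token ids, left commented out so the training run is unchanged) confirms the model maps a flat token tensor plus offsets to one row of 4 class logits per sample:

# dummy_text = torch.tensor([1, 2, 3, 4, 5], dtype=torch.int64).to(device)  # two samples of lengths 3 and 2
# dummy_offsets = torch.tensor([0, 3]).to(device)
# print(model(dummy_text, dummy_offsets).shape)  # torch.Size([2, 4])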


import time
def train(dataloader):
    model.train()
    total_acc, total_count = 0, 0
    log_interval = 500
    start_time = time.time()

    for idx, (label, text, offsets) in enumerate(dataloader):
        optimizer.zero_grad()
        predicted_label = model(text, offsets)
        loss = criterion(predicted_label, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()
        total_acc += (predicted_label.argmax(1) == label).sum().item()
        total_count += label.size(0)
        if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches, accuracy {:8.3f}'.format(epoch, idx, len(dataloader), total_acc / total_count))
            total_acc, total_count = 0, 0
            start_time = time.time()


def evaluate(dataloader):
    model.eval()
    total_acc, total_count = 0, 0

    with torch.no_grad():
        for idx, (label, text, offsets) in enumerate(dataloader):
            predicted_label = model(text, offsets)

            total_acc += (predicted_label.argmax(1) == label).sum().item()
            total_count += label.size(0)
    return total_acc / total_count


def predict(text, text_pipeline):
    with torch.no_grad():
        text = torch.tensor(text_pipeline(text))
        output = model(text.to(device), torch.tensor([0]).to(device))  # move input tensors to the same device as the model
        # output = model(text, torch.tensor([0]))
        return output.argmax(1).item() + 1


from torch.utils.data.dataset import random_split

if __name__ == '__main__':

    EPOCHS = 10
    LR = 5
    BATCH_SIZE = 64
    #
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=LR)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
    total_accu = None
    # train_iter, test_iter = AG_NEWS(root=path)
    train_dataset = load_data(r"D:\study\dataset\AG News\train.csv")
    test_dataset = load_data(r"D:\study\dataset\AG News\test.csv")
    num_train = int(len(train_dataset) * 0.95)
    split_train_, split_valid_ = random_split(train_dataset, [num_train, len(train_dataset) - num_train])
    #
    train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)      # shuffle=True randomly reorders training samples each epoch
    valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
    test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
    #
    for epoch in range(1, EPOCHS + 1):
        epoch_start_time = time.time()
        train(train_dataloader)
        accu_val = evaluate(valid_dataloader)
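        # If validation accuracy regressed this epoch, decay the learning rate (each scheduler.step() with gamma=0.1 cuts it by 10x); otherwise record the new best accuracy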
        if total_accu is not None and total_accu > accu_val:
            scheduler.step()
        else:
            total_accu = accu_val
        print('-' * 59)
        print('| end of epoch {:3d} | time: {:5.2f}s | '
              'valid accuracy {:8.3f} '.format(epoch, time.time() - epoch_start_time, accu_val))
        print('-' * 59)
    #
    #
    print('Checking the results of test dataset.')

    accu_test = evaluate(test_dataloader)
    print('test accuracy {:8.3f}'.format(accu_test))
    #
    torch.save(model.state_dict(), r'D:\study\dataset\AG News\model_TextClassification.pth')
# Prediction section below
    # model.load_state_dict(torch.load(r'D:\study\dataset\AG News\model_TextClassification.pth'))
    # #
    # ag_news_label = {1: "World",
    #                  2: "Sports",
    #                  3: "Business",
    #                  4: "Sci/Tec"}
    #
    # # ex_text_str = "MEMPHIS, Tenn. – Four days ago, Jon Rahm was enduring the season’s worst weather conditions on Sunday at The Open on his way to a closing 75 at Royal Portrush, which considering the wind and the rain was a respectable showing. Thursday’s first round at the WGC-FedEx St. Jude Invitational was another story. With temperatures in the mid-80s and hardly any wind, the Spaniard was 13 strokes better in a flawless round. Thanks to his best putting performance on the PGA Tour, Rahm finished with an 8-under 62 for a three-stroke lead, which was even more impressive considering he’d never played the front nine at TPC Southwind."
    # ex_text_str = 'Beijing of Automation, Beijing Institute of Technology'
    # # model = model.to("cpu")
    #
    # print("This is a %s news" % ag_news_label[predict(ex_text_str, text_pipeline)])

# Reference posts (modified, not identical; even after downloading the dataset locally, the built-in loader still misbehaved for me): 【PyTorch】7 文本分类TorchText实战——AG_NEWS四类别新闻分类 (CSDN博客); Torchtext下的AG_NEWS数据集进行分类(官方文档代码) (CSDN博客)

The original AG corpus gathers 496,835 news articles in its 4 largest categories from more than 2,000 news sources; this classification dataset uses only the title and description fields. Each class has 30,000 training samples and 1,900 test samples.

README: AG's News Topic Classification Dataset, Version 3, updated 09/09/2015.

ORIGIN: AG is a collection of more than 1 million news articles, gathered from more than 2,000 news sources by ComeToMyHead in more than 1 year of activity. ComeToMyHead is an academic news search engine which has been running since July 2004. The dataset is provided by the academic community for research purposes in data mining (clustering, classification, etc.), information retrieval (ranking, search, etc.), XML, data compression, data streaming, and any other non-commercial activity. For more information, please refer to http://www.di.unipi.it/~gulli/AG_corpus_of_news_articles.html . The AG's news topic classification dataset was constructed by Xiang Zhang (xiang.zhang@nyu.edu) from the collection above. It is used as a text classification benchmark in: Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances in Neural Information Processing Systems 28 (NIPS 2015).

DESCRIPTION: The dataset is constructed by choosing the 4 largest classes from the original corpus. Each class contains 30,000 training samples and 1,900 testing samples, for 120,000 training and 7,600 testing samples in total. The file classes.txt lists the class corresponding to each label. The files train.csv and test.csv contain all samples as comma-separated values in 3 columns: class index (1 to 4), title, and description. The title and description are escaped using double quotes ("), any internal double quote is escaped by 2 double quotes (""), and new lines are escaped as a backslash followed by an "n" character, i.e. "\n".
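Each row of train.csv/test.csv therefore carries three double-quoted fields. A made-up example row (class index 3 = Business, matching the label map in the prediction code):

"3","Example business headline","Example description of the article."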