Week N3: Getting Started with PyTorch Text Classification

This task's overall workflow is roughly illustrated with a diagram borrowed from K同学.
(Figure: overall workflow of the text classification task)

1. Introduction to the AG News Dataset Used in This Task

The AG News dataset is a widely used dataset for text classification. It contains news articles from the AG news website in four categories: World, Sports, Business, and Sci/Tech. Each category contains about 30,000 articles, for a total of roughly 120,000 articles.

AG News is a common benchmark for natural language processing and machine learning and is often used to measure the performance of text classification algorithms. It has the following characteristics:

  1. Large scale: the dataset contains a large number of news articles, enough to train deep learning models.
  2. Diversity: the articles cover different topics, so the dataset can be used to test a model's ability to classify across categories.
  3. Easy to use: the dataset is widely adopted, and many open-source projects and tutorials are available to help users get started.

The AG News dataset can be used for several natural language processing tasks, such as text classification, sentiment analysis, and topic identification.
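
As a quick way to see the (label, text) structure described above, the short snippet below peeks at one training sample. It is a sketch added for illustration and assumes the same torchtext AG_NEWS loader used in the code section; the class-name mapping follows the dataset description (labels 1 to 4).

from torchtext.datasets import AG_NEWS

ag_news_classes = {1: "World", 2: "Sports", 3: "Business", 4: "Sci/Tech"}  # labels are 1-based in the raw data

sample_iter = AG_NEWS(split='train')
label, text = next(iter(sample_iter))           # each item is a (label, text) pair
print(ag_news_classes[label], '->', text[:80])  # class name and the first 80 characters of the article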

2. The TextClassificationModel Architecture

  • The TextClassificationModel class inherits from nn.Module. Its __init__ method builds the model's components, init_weights initializes the weights, and forward defines how data flows through the model.
  • The model consists of a word embedding layer (embedding) and a fully connected layer (fc).
  • The init_weights method initializes the model's weights.
  • The forward method takes the input token sequence (text) and the offsets (offsets), obtains the embedded representation from the embedding layer, and then produces the class prediction through the fully connected layer (see the EmbeddingBag sketch after this list).
  • num_class is the number of classes, vocab_size is the vocabulary size, and em_size is the embedding dimension.
  • Finally, a model object is created and moved to the specified device (device).
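
Because the model flattens a whole batch of variable-length sentences into a single token tensor plus per-sentence offsets, a small standalone nn.EmbeddingBag example may make the text/offsets convention clearer. This is a minimal sketch for illustration only; the toy vocabulary size and token ids below are made up.

import torch
import torch.nn as nn

# toy setup: 10-token vocabulary, 4-dimensional embeddings, mean pooling per bag (the default mode)
embedding_bag = nn.EmbeddingBag(num_embeddings=10, embedding_dim=4)

# two "sentences" flattened into one tensor: [2, 5, 7] and [1, 3]
text = torch.tensor([2, 5, 7, 1, 3], dtype=torch.int64)
offsets = torch.tensor([0, 3], dtype=torch.int64)  # sentence i starts at position offsets[i]

pooled = embedding_bag(text, offsets)
print(pooled.shape)  # torch.Size([2, 4]): one pooled vector per sentence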

3. Code

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, datasets
import os, PIL, pathlib, warnings

warnings.filterwarnings("ignore")
# Windows 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

from torchtext.datasets import AG_NEWS
train_iter = AG_NEWS(split='train')  # load the AG News training split

from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

# get a tokenizer that lower-cases and splits on basic English punctuation
tokenizer = get_tokenizer('basic_english')

def yield_tokens(data_iter):
    for _, text in data_iter:
        yield tokenizer(text)

# build the vocabulary from the training iterator; "<unk>" handles out-of-vocabulary tokens
vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])  # set the default index for unknown tokens
print(vocab(['here', 'is', 'an', 'example']))

# text pipeline: raw string -> list of token ids; label pipeline: 1-based label -> 0-based int
text_pipeline = lambda x: vocab(tokenizer(x))
label_pipeline = lambda x: int(x) - 1
print(text_pipeline('here is an example '))
print(label_pipeline('10'))
 
 
from torch.utils.data import DataLoader

def collate_batch(batch):
    label_list, text_list, offsets = [], [], [0]
    for (_label, _text) in batch:
        # label list
        label_list.append(label_pipeline(_label))
        # text list
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(processed_text)
        # offsets, i.e. the number of tokens in each sentence
        offsets.append(processed_text.size(0))
    label_list = torch.tensor(label_list, dtype=torch.int64)
    text_list = torch.cat(text_list)
    # cumulative sum of the per-sentence lengths gives each sentence's start offset
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
    return label_list.to(device), text_list.to(device), offsets.to(device)

# data loader
dataloader = DataLoader(train_iter, batch_size=8, shuffle=False, collate_fn=collate_batch)
 
 
 
from torch import nn

class TextClassificationModel(nn.Module):

    def __init__(self, vocab_size, embed_dim, num_class):
        super(TextClassificationModel, self).__init__()
        self.embedding = nn.EmbeddingBag(vocab_size,   # vocabulary size
                                         embed_dim,    # embedding dimension
                                         sparse=False)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)

num_class = len(set([label for (label, text) in train_iter]))
vocab_size = len(vocab)
em_size = 64
model = TextClassificationModel(vocab_size, em_size, num_class).to(device)
 
import time

def train(dataloader):
    model.train()  # switch to training mode
    total_acc, train_loss, total_count = 0, 0, 0
    log_interval = 500
    start_time = time.time()

    for idx, (label, text, offsets) in enumerate(dataloader):
        predicted_label = model(text, offsets)
        optimizer.zero_grad()                     # reset gradients
        loss = criterion(predicted_label, label)  # loss between network output and the true label
        loss.backward()                           # backpropagation
        optimizer.step()                          # update parameters
        # record accuracy and loss
        total_acc += (predicted_label.argmax(1) == label).sum().item()
        train_loss += loss.item()
        total_count += label.size(0)
        if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time
            print('|epoch {:1d}|{:4d}/{:4d}batches'
                  '|train_acc {:4.3f}train_loss {:4.5f}'.format(epoch, idx, len(dataloader),
                                                                total_acc/total_count, train_loss/total_count))
            total_acc, train_loss, total_count = 0, 0, 0
            start_time = time.time()

def evaluate(dataloader):
    model.eval()  # switch to evaluation mode
    total_acc, train_loss, total_count = 0, 0, 0

    with torch.no_grad():
        for idx, (label, text, offsets) in enumerate(dataloader):
            predicted_label = model(text, offsets)
            loss = criterion(predicted_label, label)  # loss on the held-out data
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            train_loss += loss.item()
            total_count += label.size(0)

    return total_acc/total_count, train_loss/total_count
 
 
 
from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset

# hyperparameters
EPOCHS = 10      # number of training epochs
LR = 5           # learning rate
BATCH_SIZE = 64  # batch size for training

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
total_accu = None

train_iter, test_iter = AG_NEWS()  # load the training and test splits
train_dataset = to_map_style_dataset(train_iter)
test_dataset = to_map_style_dataset(test_iter)
num_train = int(len(train_dataset) * 0.95)

# hold out 5% of the training data for validation
split_train_, split_valid_ = random_split(train_dataset,
                                          [num_train, len(train_dataset) - num_train])
train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE,
                              shuffle=True, collate_fn=collate_batch)
valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE,
                              shuffle=True, collate_fn=collate_batch)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE,
                             shuffle=True, collate_fn=collate_batch)

for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()
    train(train_dataloader)
    val_acc, val_loss = evaluate(valid_dataloader)

    # decay the learning rate only when validation accuracy stops improving
    if total_accu is not None and total_accu > val_acc:
        scheduler.step()
    else:
        total_accu = val_acc
    print('-' * 69)
    print('|epoch {:1d}|time:{:4.2f}s|'
          'valid_acc {:4.3f}valid_loss {:4.3f}'.format(epoch,
          time.time() - epoch_start_time, val_acc, val_loss))
    print('-' * 69)

print('Checking the results of test dataset.')
test_acc, test_loss = evaluate(test_dataloader)
print('test accuracy {:8.3f}'.format(test_acc))

Running the script produces the following console output:
E:\BaiduNetdiskDownload\pythonProject_PyTorch\venv\Scripts\python.exe E:\BaiduNetdiskDownload\pythonProject_PyTorch\PytorchText.py 
[475, 21, 30, 5297]
[475, 21, 30, 5297]
9
|epoch 1| 500/1782batches|train_acc 0.721train_loss 0.01010
|epoch 1|1000/1782batches|train_acc 0.871train_loss 0.00616
|epoch 1|1500/1782batches|train_acc 0.877train_loss 0.00542
---------------------------------------------------------------------
|epoch 1|time:11.81s|valid_acc 0.794valid_loss 0.009
---------------------------------------------------------------------
|epoch 2| 500/1782batches|train_acc 0.903train_loss 0.00451
|epoch 2|1000/1782batches|train_acc 0.906train_loss 0.00442
|epoch 2|1500/1782batches|train_acc 0.906train_loss 0.00436
---------------------------------------------------------------------
|epoch 2|time:11.64s|valid_acc 0.845valid_loss 0.007
---------------------------------------------------------------------
|epoch 3| 500/1782batches|train_acc 0.919train_loss 0.00374
|epoch 3|1000/1782batches|train_acc 0.917train_loss 0.00383
|epoch 3|1500/1782batches|train_acc 0.915train_loss 0.00393
---------------------------------------------------------------------
|epoch 3|time:11.61s|valid_acc 0.905valid_loss 0.004
---------------------------------------------------------------------
|epoch 4| 500/1782batches|train_acc 0.927train_loss 0.00339
|epoch 4|1000/1782batches|train_acc 0.926train_loss 0.00342
|epoch 4|1500/1782batches|train_acc 0.922train_loss 0.00352
---------------------------------------------------------------------
|epoch 4|time:11.62s|valid_acc 0.870valid_loss 0.006
---------------------------------------------------------------------
|epoch 5| 500/1782batches|train_acc 0.942train_loss 0.00276
|epoch 5|1000/1782batches|train_acc 0.945train_loss 0.00268
|epoch 5|1500/1782batches|train_acc 0.945train_loss 0.00266
---------------------------------------------------------------------
|epoch 5|time:11.67s|valid_acc 0.913valid_loss 0.004
---------------------------------------------------------------------
|epoch 6| 500/1782batches|train_acc 0.946train_loss 0.00259
|epoch 6|1000/1782batches|train_acc 0.946train_loss 0.00261
|epoch 6|1500/1782batches|train_acc 0.946train_loss 0.00261
---------------------------------------------------------------------
|epoch 6|time:11.71s|valid_acc 0.914valid_loss 0.004
---------------------------------------------------------------------
|epoch 7| 500/1782batches|train_acc 0.948train_loss 0.00255
|epoch 7|1000/1782batches|train_acc 0.946train_loss 0.00260
|epoch 7|1500/1782batches|train_acc 0.948train_loss 0.00250
---------------------------------------------------------------------
|epoch 7|time:11.68s|valid_acc 0.912valid_loss 0.004
---------------------------------------------------------------------
|epoch 8| 500/1782batches|train_acc 0.948train_loss 0.00252
|epoch 8|1000/1782batches|train_acc 0.948train_loss 0.00249
|epoch 8|1500/1782batches|train_acc 0.950train_loss 0.00244
---------------------------------------------------------------------
|epoch 8|time:11.52s|valid_acc 0.913valid_loss 0.004
---------------------------------------------------------------------
|epoch 9| 500/1782batches|train_acc 0.949train_loss 0.00249
|epoch 9|1000/1782batches|train_acc 0.950train_loss 0.00246
|epoch 9|1500/1782batches|train_acc 0.950train_loss 0.00248
---------------------------------------------------------------------
|epoch 9|time:11.43s|valid_acc 0.922valid_loss 0.004
---------------------------------------------------------------------
|epoch 10| 500/1782batches|train_acc 0.950train_loss 0.00235
|epoch 10|1000/1782batches|train_acc 0.950train_loss 0.00255
|epoch 10|1500/1782batches|train_acc 0.949train_loss 0.00229
---------------------------------------------------------------------
|epoch 10|time:11.92s|valid_acc 0.924valid_loss 0.004
---------------------------------------------------------------------
Checking the results of test dataset.
test accuracy    0.909
 
Process finished with exit code 0
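
With training finished, the model can be tried on a new headline. The helper below is a minimal sketch that is not part of the run above; it reuses the model, text_pipeline, and device defined in the code section, and shifts the predicted index back to the dataset's 1-based labels. The example headline is made up.

ag_news_label = {1: "World", 2: "Sports", 3: "Business", 4: "Sci/Tech"}

def predict(text, text_pipeline):
    with torch.no_grad():
        text = torch.tensor(text_pipeline(text), dtype=torch.int64).to(device)
        offsets = torch.tensor([0]).to(device)  # a single sentence starts at offset 0
        output = model(text, offsets)
        return output.argmax(1).item() + 1      # back to the dataset's 1-based label

example_news = "Oil prices rise as markets weigh supply concerns"
print(ag_news_label[predict(example_news, text_pipeline)])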
For reference, the official description of the AG's News Topic Classification Dataset (Version 3, updated 09/09/2015) is reproduced below.

ORIGIN: AG is a collection of more than 1 million news articles. News articles have been gathered from more than 2000 news sources by ComeToMyHead in more than 1 year of activity. ComeToMyHead is an academic news search engine which has been running since July 2004. The dataset is provided by the academic community for research purposes in data mining (clustering, classification, etc.), information retrieval (ranking, search, etc.), XML, data compression, data streaming, and any other non-commercial activity. For more information, please refer to http://www.di.unipi.it/~gulli/AG_corpus_of_news_articles.html. The AG's news topic classification dataset is constructed by Xiang Zhang (xiang.zhang@nyu.edu) from the dataset above. It is used as a text classification benchmark in the following paper: Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances in Neural Information Processing Systems 28 (NIPS 2015).

DESCRIPTION: The AG's news topic classification dataset is constructed by choosing the 4 largest classes from the original corpus. Each class contains 30,000 training samples and 1,900 testing samples. The total number of training samples is 120,000 and testing 7,600. The file classes.txt contains a list of classes corresponding to each label. The files train.csv and test.csv contain all the training samples as comma-separated values. There are 3 columns in them, corresponding to class index (1 to 4), title, and description. The title and description are escaped using double quotes ("), and any internal double quote is escaped by 2 double quotes (""). New lines are escaped by a backslash followed by an "n" character, that is "\n".
