Getting Started with Text Classification in PyTorch

>- **🍨 This post is a learning-record blog for the [🔗365天深度学习训练营](https://mp.weixin.qq.com/s/AtyZUu_j2k_ScNH6e732ow)**
>- **🍖 Original author: [K同学啊 | tutoring and custom projects](https://mtyjkh.blog.csdn.net/)**
>- **🚀 Source: [K同学的学习圈子](https://www.yuque.com/mingtian-fkmxf/zxwb45)**

This post uses the AG News dataset.

AG News (AG's News Topic Classification Dataset) is a dataset widely used for text classification, especially in the news domain. It was compiled from AG's Corpus of News Articles and covers four classes: World, Sports, Business, and Sci/Tech.
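Before building anything, it helps to peek at one raw record. A minimal sketch, assuming torchtext and its dataset dependencies are installed; each sample is a (label, text) pair, and the values in the comments are examples:

from torchtext.datasets import AG_NEWS

train_iter = AG_NEWS(split='train')  # iterator over (label, text) pairs
label, text = next(iter(train_iter))
print(label)  # e.g. 3 (labels run 1-4: World, Sports, Business, Sci/Tech)
print(text)   # e.g. "Wall St. Bears Claw Back Into the Black (Reuters) ..."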

Loading the data and building the vocabulary:

from torchtext.datasets import AG_NEWS
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

tokenizer = get_tokenizer('basic_english')  # lower-cases and splits text into tokens

def yield_tokens(data_iter):
    for _, text in data_iter:
        yield tokenizer(text)

train_iter = AG_NEWS(split='train')  # fresh iterator over the training split
vocab = build_vocab_from_iterator(yield_tokens(train_iter),
                                  specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])

The build_vocab_from_iterator function builds the vocabulary from the tokens produced by yield_tokens; it consumes an iterator of tokenized texts (yield_tokens(train_iter)).

Because the default index is set to the "<unk>" special, any word not found in the vocabulary is mapped to the "<unk>" token.
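With the vocabulary built, it is convenient to wrap tokenization and label conversion in two small helpers. A minimal sketch; the names text_pipeline and label_pipeline are conventions borrowed from the torchtext tutorials, not fixed API:

text_pipeline = lambda x: vocab(tokenizer(x))  # raw text -> list of token indices
label_pipeline = lambda x: int(x) - 1          # AG News labels 1..4 -> 0..3

print(text_pipeline('here is an example'))  # e.g. [475, 21, 30, 5297]; indices depend on the built vocab
print(label_pipeline('3'))                  # 2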

Defining the model:

from torch import nn

class TextClassificationModel(nn.Module):

    def __init__(self, vocab_size, embed_dim, num_class):
        super(TextClassificationModel, self).__init__()
        
        self.embedding = nn.EmbeddingBag(vocab_size, 
                                         embed_dim, 
                                         sparse=False) 
        
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)
  • nn.EmbeddingBag: computes the embedding of a whole "bag" of tokens (by default the mean of the token embeddings). It is similar to nn.Embedding, but it handles variable-length sequences efficiently via offsets, with no padding required.
  • vocab_size: size of the vocabulary.
  • embed_dim: dimensionality of the word embeddings.
  • sparse=False: gradients for the embedding weights are computed as dense tensors.
  • nn.Linear: a fully connected (linear) layer that maps its input to the output.
  • embed_dim: number of input features (the embedding dimensionality).
  • num_class: number of output features (the number of classes).
  • forward defines the model's forward pass:
  • it takes text (a flat tensor of token indices) and offsets (the start position of each sequence within text) as input,
  • embeds and pools the input with the embedding layer,
  • and applies the fully connected layer to the pooled embeddings (see the instantiation sketch after this list).
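A minimal instantiation sketch; embed_dim = 64 is an assumed hyperparameter (not prescribed by the model), and num_class = 4 matches AG News:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

vocab_size = len(vocab)  # vocabulary built earlier
embed_dim = 64           # assumed embedding size; tune as needed
num_class = 4            # AG News: World, Sports, Business, Sci/Tech
model = TextClassificationModel(vocab_size, embed_dim, num_class).to(device)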

Training the model:

import time

# model, optimizer, criterion, and epoch are module-level globals:
# the model was created above; the loss, optimizer, and epoch loop appear below.
def train(dataloader):
    model.train()  # Switch to training mode
    total_acc, train_loss, total_count = 0, 0, 0
    log_interval = 500
    start_time = time.time()

    for idx, (label, text, offsets) in enumerate(dataloader):
        predicted_label = model(text, offsets)

        optimizer.zero_grad()  # Zero out gradients
        loss = criterion(predicted_label, label)  # Compute the loss between predicted and true labels
        loss.backward()  # Backward pass to compute gradients
        optimizer.step()  # Update weights based on the computed gradients

        # Record accuracy and loss
        total_acc += (predicted_label.argmax(1) == label).sum().item()
        train_loss += loss.item()
        total_count += label.size(0)

        # Print training progress
        if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time
            print('| epoch {:1d} | {:4d}/{:4d} batches '
                  '| train_acc {:4.3f} train_loss {:4.5f}'.format(epoch, idx, len(dataloader),
                                              total_acc / total_count, train_loss / total_count))
            total_acc, train_loss, total_count = 0, 0, 0
            start_time = time.time()
  • Sets the model to training mode with model.train().
  • Iterates over the batches provided by dataloader; each batch is a (label, text, offsets) triple (see the collate_fn sketch after this list).
  • Computes predictions and the loss, then performs a backward pass and an optimizer step to update the model parameters.
  • Accumulates accuracy and loss for monitoring.
  • Prints training progress every log_interval batches.
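Both train and evaluate (defined next) expect each batch as a (label, text, offsets) triple. A minimal collate_fn sketch that produces this format, reusing text_pipeline, label_pipeline, and device from the earlier sketches; the batch size is illustrative, and using the test split for validation is a simplification:

import torch
from torch.utils.data import DataLoader
from torchtext.data.functional import to_map_style_dataset

def collate_batch(batch):
    label_list, text_list, offsets = [], [], [0]
    for _label, _text in batch:
        label_list.append(label_pipeline(_label))
        processed = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(processed)
        offsets.append(processed.size(0))
    label_list = torch.tensor(label_list, dtype=torch.int64)
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)  # start index of each sequence in the flat tensor
    text_list = torch.cat(text_list)                    # all token indices concatenated into one 1-D tensor
    return label_list.to(device), text_list.to(device), offsets.to(device)

train_dataloader = DataLoader(to_map_style_dataset(AG_NEWS(split='train')),
                              batch_size=64, shuffle=True, collate_fn=collate_batch)
valid_dataloader = DataLoader(to_map_style_dataset(AG_NEWS(split='test')),
                              batch_size=64, shuffle=False, collate_fn=collate_batch)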
def evaluate(dataloader):
    model.eval()  # Switch to evaluation mode
    total_acc, total_loss, total_count = 0, 0, 0

    with torch.no_grad():
        for idx, (label, text, offsets) in enumerate(dataloader):
            predicted_label = model(text, offsets)

            loss = criterion(predicted_label, label)  # Compute loss without backpropagation

            # Record accuracy and loss for evaluation
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            total_loss += loss.item()
            total_count += label.size(0)

    return total_acc / total_count, total_loss / total_count
  • Sets the model to evaluation mode with model.eval().
  • Iterates over the batches provided by dataloader.
  • Computes predictions and the loss without backpropagation (under torch.no_grad()).
  • Accumulates accuracy and loss and returns their averages over the dataset.

In the outer training loop, call train on the training data each epoch and then call evaluate on validation or test data to monitor the model's performance, as sketched below.
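A minimal driver sketch; EPOCHS and LR are illustrative hyperparameters, and the StepLR schedule is one common choice rather than a requirement:

from torch import nn

EPOCHS = 10  # illustrative
LR = 5.0     # illustrative; plain SGD with a large LR plus decay is common for EmbeddingBag models

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

for epoch in range(1, EPOCHS + 1):
    train(train_dataloader)                         # one pass over the training data
    val_acc, val_loss = evaluate(valid_dataloader)  # monitor on held-out data
    scheduler.step()
    print('| epoch {:1d} | valid_acc {:4.3f} valid_loss {:4.5f}'.format(epoch, val_acc, val_loss))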

Training results:

Dataset README (AG's News Topic Classification Dataset, Version 3, updated 09/09/2015):

The benchmark draws on 496,835 news articles from the 4 largest classes of the AG news corpus, gathered from more than 2,000 news sources; only the title and description fields are used. Each class contributes 30,000 training samples and 1,900 test samples.

ORIGIN: AG is a collection of more than 1 million news articles, gathered from more than 2,000 news sources by ComeToMyHead, an academic news search engine that has been running since July 2004, over more than a year of activity. The dataset is provided by the academic community for research purposes in data mining (clustering, classification, etc.), information retrieval (ranking, search, etc.), XML, data compression, data streaming, and any other non-commercial activity. For more information, see http://www.di.unipi.it/~gulli/AG_corpus_of_news_articles.html. The AG's news topic classification dataset was constructed by Xiang Zhang (xiang.zhang@nyu.edu) from the corpus above and is used as a text-classification benchmark in: Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances in Neural Information Processing Systems 28 (NIPS 2015).

DESCRIPTION: The dataset was constructed by choosing the 4 largest classes from the original corpus. Each class contains 30,000 training samples and 1,900 testing samples, for 120,000 training and 7,600 testing samples in total. The file classes.txt lists the class corresponding to each label. The files train.csv and test.csv contain the samples as comma-separated values in 3 columns: class index (1 to 4), title, and description. Titles and descriptions are escaped with double quotes ("); internal double quotes are escaped by doubling (""); newlines are escaped as a backslash followed by an "n" character ("\n").