【PyTorch】7 TorchText in Practice for Text Classification — Four-Class AG_NEWS News Classification

This is one of the official text tutorials (there are Chinese translations of the original PyTorch 1.4 and 1.7 versions, plus the original English documentation). It introduces how to use the text classification datasets in torchtext, and this article is a detailed annotation of it. For the official English documentation of the TorchText API, see the referenced blog post.

This example shows how to use one of these TextClassification datasets to train a supervised learning algorithm for classification.

The ngrams feature is used to capture some partial information about local word order. In practice, applying bi-grams or tri-grams as word groups provides more benefit than single words alone. An example:

"load data with ngrams"
Bi-grams results: "load data", "data with", "with ngrams"
Tri-grams results: "load data with", "data with ngrams"

The TextClassification datasets support the ngrams method. By setting ngrams to 2, the example text in the dataset will be a list of single words plus bi-gram strings.
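
The bi-grams above can be produced with torchtext's ngrams_iterator helper; a minimal sketch:

from torchtext.data.utils import get_tokenizer, ngrams_iterator

tokenizer = get_tokenizer('basic_english')
tokens = tokenizer("load data with ngrams")
print(list(ngrams_iterator(tokens, 2)))
# ['load', 'data', 'with', 'ngrams', 'load data', 'data with', 'with ngrams']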

Install it with:

pip install torchtext

The from torchtext.datasets import text_classification line in the original tutorial no longer works, and the parameters of text_classification.DATASETS['AG_NEWS'] have changed as well; see the English manual for details.

1. Accessing the raw dataset iterators

The torchtext library provides a few raw dataset iterators, which yield the raw text strings. For example, the AG_NEWS dataset iterator yields the raw data as tuples of label and text.

Calling this function as train_data, test_dataset = AG_NEWS(root=path, split=('train', 'test')) raises an error:

TimeoutError: [WinError 10060] A connection attempt failed because the connected party did not properly respond after a period of time, or established connection failed because connected host has failed to respond

So here we download directly from the URLs:

URL = {
    'train': "https://raw.githubusercontent.com/mhjabreel/CharCnn_Keras/master/data/ag_news_csv/train.csv",
    'test': "https://raw.githubusercontent.com/mhjabreel/CharCnn_Keras/master/data/ag_news_csv/test.csv",
}
from torchtext.datasets import AG_NEWS
path = '... your path\\AG_NEWS.data'

train_data, test_dataset = AG_NEWS(root=path, split=('train', 'test'))

print(next(train_data))
print(next(train_data))
(3, "Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again.")
(3, 'Carlyle Looks Toward Commercial Aerospace (Reuters) Reuters - Private investment firm Carlyle Group,\\which has a reputation for making well-timed and occasionally\\controversial plays in the defense industry, has quietly placed\\its bets on another part of the market.')

2. Preparing the data processing pipelines

We have revisited the very basic components of the torchtext library, including vocab, word vectors, and tokenizer. These are the basic data processing building blocks for raw text strings.

Here is an example of typical NLP data processing with a tokenizer and a vocabulary. The first step is to build a vocabulary from the raw training dataset. Users can have a customized vocabulary by setting up arguments in the constructor of the Vocab class, for example the minimum frequency min_freq of tokens to be included.

As for lambda, this expression is an anonymous function, the counterpart of a named function defined with def in Python.
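
A trivial illustration of the equivalence (generic Python, not from the tutorial):

add_one = lambda x: x + 1       # anonymous function

def add_one_def(x):             # the equivalent named function
    return x + 1

assert add_one(3) == add_one_def(3) == 4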

The vocab block converts a list of tokens into integers.

[vocab[token] for token in ['here', 'is', 'an', 'example']]
>>> [476, 22, 31, 5298]
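
The vocabulary itself is built from the raw training iterator using the legacy torchtext Vocab API, as in the full code in section 8; a condensed sketch:

from collections import Counter
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import Vocab
from torchtext.datasets import AG_NEWS

tokenizer = get_tokenizer('basic_english')
train_iter = AG_NEWS(root=path, split='train')
counter = Counter()
for (label, line) in train_iter:
    counter.update(tokenizer(line))
vocab = Vocab(counter, min_freq=1)      # min_freq: minimum frequency for a token to be kept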

Prepare the text processing pipeline with the tokenizer and the vocabulary. The text and label pipelines will be used to process the raw data strings from the dataset iterators.

The text pipeline converts a text string into a list of integers based on the lookup table defined in the vocabulary. The label pipeline converts the label into an integer. For example:

text_pipeline('here is the an example')
>>> [475, 21, 2, 30, 5286]
label_pipeline('10')
>>> 9

3. Generating data batches and iterators

torch.utils.data.DataLoader is recommended for PyTorch users (a tutorial is available here). It works with a map-style dataset that implements the getitem() and len() protocols, and represents a map from indices/keys to data samples. It also works with an iterable dataset when the shuffle argument is False.

Before being sent to the model, the collate_fn function processes a batch of samples generated from the DataLoader. The input of collate_fn is a batch of data with the batch size in the DataLoader, and collate_fn processes it according to the data processing pipelines declared earlier. Make sure that collate_fn is declared as a top-level def, which guarantees that the function is available in each worker.

In this example, the text entries in the original data batch input are packed into a list and concatenated as a single tensor for the input of nn.EmbeddingBag. The offsets is a tensor of delimiters representing the beginning index of each individual sequence in the text tensor. Label is a tensor saving the labels of the individual text entries.

On the usage of torch.cumsum():

x = torch.arange(0, 6).view(2, 3)
print(x)
print(x.cumsum(dim=0))
print(x.cumsum(dim=1))
tensor([[0, 1, 2],
        [3, 4, 5]])
tensor([[0, 1, 2],
        [3, 5, 7]])
tensor([[ 0,  1,  3],
        [ 3,  7, 12]])

My understanding is that collate_fn takes a batch of data from the list of samples, runs it through the mapping functions, and turns it into tensors.
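
To make the offsets concrete: collate_batch (see the full code in section 8) collects the length of each processed text, drops the last one, and runs cumsum, so each entry marks where a text starts in the concatenated tensor. A small sketch with made-up lengths:

import torch

lengths = [41, 58, 30, 25]                          # token counts of 4 texts in a batch
offsets = torch.tensor([0] + lengths[:-1]).cumsum(dim=0)
print(offsets)                                      # tensor([  0,  41,  99, 129])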

4. Defining the model

The model is composed of an nn.EmbeddingBag layer plus a linear layer for the classification purpose. nn.EmbeddingBag with the default mode of "mean" computes the mean value of a "bag" of embeddings. Although the text entries here have different lengths, the nn.EmbeddingBag module requires no padding because the text lengths are saved in offsets.

In addition, since nn.EmbeddingBag accumulates the average across the embeddings on the fly, nn.EmbeddingBag can enhance the performance and memory efficiency when processing a sequence of tensors.

[Figure: the model architecture — an nn.EmbeddingBag layer followed by a linear layer]
About the EmbeddingBag() function (official docs, see also this reference): compared with Embedding, there is only one extra parameter, mode, which takes three values corresponding to three operations: "sum" is equivalent to a plain embedding followed by torch.sum(dim=0), "mean" is equivalent to following it with torch.mean(dim=0), and "max" is equivalent to following it with torch.max(dim=0).

An example of the input and output of this network:

>>> # an Embedding module containing 10 tensors of size 3
>>> embedding_sum = nn.EmbeddingBag(10, 3, mode='sum')
>>> # a batch of 2 samples of 4 indices each
>>> input = torch.LongTensor([1,2,4,5,4,3,2,9])
>>> offsets = torch.LongTensor([0,4])
>>> embedding_sum(input, offsets)
tensor([[-0.8861, -5.4350, -0.0523],
        [ 1.1306, -2.5798, -1.0044]])
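
As a quick sanity check of the mode semantics described above (my own example, not from the official docs), mode='mean' matches a plain nn.Embedding followed by a per-bag mean:

import torch
from torch import nn

torch.manual_seed(0)
bag = nn.EmbeddingBag(10, 3, mode='mean')
emb = nn.Embedding(10, 3)
emb.weight.data.copy_(bag.weight.data)              # share the same weights

input = torch.LongTensor([1, 2, 4, 5, 4, 3, 2, 9])
offsets = torch.LongTensor([0, 4])                  # two bags: input[0:4] and input[4:8]
out_bag = bag(input, offsets)
out_manual = torch.stack([emb(input[0:4]).mean(dim=0), emb(input[4:8]).mean(dim=0)])
print(torch.allclose(out_bag, out_manual))          # True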

5. Initializing an instance

The AG_NEWS dataset has four labels, and therefore the number of classes is four:

1 : World
2 : Sports
3 : Business
4 : Sci/Tec

We build a model with an embedding dimension of 64. The vocab size is equal to the length of the vocabulary instance, and the number of classes is equal to the number of labels, which is 4.
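
This corresponds to the following lines in the full code:

num_class = 4                   # four AG_NEWS categories
vocab_size = len(vocab)
emsize = 64                     # embedding dimension
model = TextClassificationModel(vocab_size, emsize, num_class).to(device)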

6. Defining functions to train the model and evaluate results

On adjusting the learning rate (official docs): torch.optim.lr_scheduler provides several methods to adjust the learning rate based on the number of epochs.

torch.optim.lr_scheduler.StepLR decays the learning rate of each parameter group by gamma every step_size epochs. Note that such decay can happen simultaneously with other changes to the learning rate from outside this scheduler. When last_epoch=-1, the initial lr is set to lr.
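
A small sketch (a hypothetical toy parameter, with step_size=1 and gamma=0.1 as used later in this article) showing the decay:

import torch
from torch import nn

param = nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=5.0)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

for epoch in range(3):
    optimizer.step()
    print(epoch, scheduler.get_last_lr())    # [5.0], then [0.5], then [0.05]
    scheduler.step()                         # decay the lr by gamma after each epoch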

About torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1) (official docs): it clips the gradient norm of an iterable of parameters. The norm is computed over all gradients together, as if they were concatenated into a single vector, and the gradients are modified in place. In other words, gradient clipping enforces a maximum norm max_norm that the gradients cannot exceed.
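
A quick illustration of the effect, with made-up gradient values:

import torch
from torch import nn

p = nn.Parameter(torch.ones(4))
p.grad = torch.tensor([3.0, 4.0, 0.0, 0.0])         # gradient norm is 5.0
torch.nn.utils.clip_grad_norm_([p], max_norm=0.1)   # rescale so the total norm becomes 0.1
print(p.grad)                                       # tensor([0.0600, 0.0800, 0.0000, 0.0000])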

For each batch, the predicted predited_label is a 64×4 tensor, while each label is a one-dimensional tensor of length 64:

tensor([[ 0.4427,  0.0830,  0.0109,  0.1273],
        [ 0.1601,  0.0869, -0.0540,  0.0422],
        ...
tensor([0, 0, 0, 3, 1, 1, 1, 3, 3, 3, 3, 3, 1, 1, 3, 1, 1, 3, 3, 3, 1, 1, 3, 3,
        3, 1, 1, 2, 1, 2, 1, 1, 3, 3, 1, 1, 1, 3, 1, 3, 0, 1, 0, 0, 1, 3, 3, 3,
        2, 3, 1, 3, 3, 3, 1, 3, 3, 1, 1, 2, 0, 2, 1, 3])

Previously we used the .topk() function; here let's look at .argmax(1):

print(predited_label.argmax(1) == label)
tensor([False,  True,  True,  True, False,  True,  True,  True,  True,  True,
         True, False,  True,  True,  True, False,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True, False,  True, False,  True,  True,  True,  True, False,  True,
         True,  True, False,  True, False,  True,  True, False,  True,  True,
         True, False, False,  True,  True, False,  True, False, False,  True,
        False,  True,  True,  True])

Executing the following code gives a single number, the count of correct predictions in the batch:

(predited_label.argmax(1) == label).sum().item()

7. Splitting the dataset and running the model

Since the original AG_NEWS has no valid dataset, we split the training dataset into train/valid sets with a split ratio of 0.95 (train) and 0.05 (valid). Here we use the torch.utils.data.dataset.random_split function from the PyTorch core library.
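
The corresponding lines from the full code:

from torch.utils.data.dataset import random_split

train_iter, test_iter = AG_NEWS(root=path)
train_dataset = list(train_iter)                # materialize the iterators into lists
test_dataset = list(test_iter)
num_train = int(len(train_dataset) * 0.95)      # 95% train / 5% valid
split_train_, split_valid_ = random_split(train_dataset, [num_train, len(train_dataset) - num_train])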

The CrossEntropyLoss criterion combines nn.LogSoftmax() and nn.NLLLoss() in a single class. It is useful when training a classification problem with C classes. SGD implements stochastic gradient descent as the optimizer. The initial learning rate is set to 5.0. StepLR is used here to adjust the learning rate through epochs.
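
The corresponding setup in the full code:

criterion = torch.nn.CrossEntropyLoss()                     # LogSoftmax + NLLLoss
optimizer = torch.optim.SGD(model.parameters(), lr=5.0)     # initial learning rate 5.0
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)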

The training process prints:

| epoch   1 |   500/ 1782 batches, accuracy    0.685
| epoch   1 |  1000/ 1782 batches, accuracy    0.852
| epoch   1 |  1500/ 1782 batches, accuracy    0.876
-----------------------------------------------------------
| end of epoch   1 | time: 15.24s | valid accuracy    0.886 
-----------------------------------------------------------
| epoch   2 |   500/ 1782 batches, accuracy    0.896
| epoch   2 |  1000/ 1782 batches, accuracy    0.902
| epoch   2 |  1500/ 1782 batches, accuracy    0.902
-----------------------------------------------------------
| end of epoch   2 | time: 15.20s | valid accuracy    0.899 
-----------------------------------------------------------
| epoch   3 |   500/ 1782 batches, accuracy    0.915
| epoch   3 |  1000/ 1782 batches, accuracy    0.914
| epoch   3 |  1500/ 1782 batches, accuracy    0.915
-----------------------------------------------------------
| end of epoch   3 | time: 15.22s | valid accuracy    0.904 
-----------------------------------------------------------
| epoch   4 |   500/ 1782 batches, accuracy    0.924
| epoch   4 |  1000/ 1782 batches, accuracy    0.924
| epoch   4 |  1500/ 1782 batches, accuracy    0.923
-----------------------------------------------------------
| end of epoch   4 | time: 15.16s | valid accuracy    0.908 
-----------------------------------------------------------
| epoch   5 |   500/ 1782 batches, accuracy    0.930
| epoch   5 |  1000/ 1782 batches, accuracy    0.929
| epoch   5 |  1500/ 1782 batches, accuracy    0.931
-----------------------------------------------------------
| end of epoch   5 | time: 15.21s | valid accuracy    0.900 
-----------------------------------------------------------
| epoch   6 |   500/ 1782 batches, accuracy    0.943
| epoch   6 |  1000/ 1782 batches, accuracy    0.941
| epoch   6 |  1500/ 1782 batches, accuracy    0.944
-----------------------------------------------------------
| end of epoch   6 | time: 15.17s | valid accuracy    0.911 
-----------------------------------------------------------
| epoch   7 |   500/ 1782 batches, accuracy    0.943
| epoch   7 |  1000/ 1782 batches, accuracy    0.945
| epoch   7 |  1500/ 1782 batches, accuracy    0.946
-----------------------------------------------------------
| end of epoch   7 | time: 15.24s | valid accuracy    0.912 
-----------------------------------------------------------
| epoch   8 |   500/ 1782 batches, accuracy    0.945
| epoch   8 |  1000/ 1782 batches, accuracy    0.944
| epoch   8 |  1500/ 1782 batches, accuracy    0.944
-----------------------------------------------------------
| end of epoch   8 | time: 15.20s | valid accuracy    0.913 
-----------------------------------------------------------
| epoch   9 |   500/ 1782 batches, accuracy    0.944
| epoch   9 |  1000/ 1782 batches, accuracy    0.948
| epoch   9 |  1500/ 1782 batches, accuracy    0.946
-----------------------------------------------------------
| end of epoch   9 | time: 15.29s | valid accuracy    0.915 
-----------------------------------------------------------
| epoch  10 |   500/ 1782 batches, accuracy    0.949
| epoch  10 |  1000/ 1782 batches, accuracy    0.945
| epoch  10 |  1500/ 1782 batches, accuracy    0.946
-----------------------------------------------------------
| end of epoch  10 | time: 15.19s | valid accuracy    0.913 
-----------------------------------------------------------
Checking the results of test dataset.
test accuracy    0.908

For a sentence like this:

"MEMPHIS, Tenn. – Four days ago, Jon Rahm was enduring the season’s worst weather conditions on Sunday at The Open on his way to a closing 75 at Royal Portrush, which considering the wind and the rain was a respectable showing. Thursday’s first round at the WGC-FedEx St. Jude Invitational was another story. With temperatures in the mid-80s and hardly any wind, the Spaniard was 13 strokes better in a flawless round. Thanks to his best putting performance on the PGA Tour, Rahm finished with an 8-under 62 for a three-stroke lead, which was even more impressive considering he’d never played the front nine at TPC Southwind."

The output is:

This is a Sports news

And for a sentence like this:

'Beijing of Automation, Beijing Institute of Technology'

The output is:

This is a Sci/Tec news

We can see that the classification results are fairly good.

8. Full code

path = '... your path\\AG_NEWS.data'

import torch
from torchtext.datasets import AG_NEWS

train_iter = AG_NEWS(root=path, split='train')      # access the raw dataset iterator

from torchtext.data.utils import get_tokenizer
from collections import Counter
from torchtext.vocab import Vocab

tokenizer = get_tokenizer('basic_english')      # tokenizer for the input strings
counter = Counter()
for (label, line) in train_iter:
    counter.update(tokenizer(line))
vocab = Vocab(counter, min_freq=1)

# prepare the data processing pipelines
text_pipeline = lambda x: [vocab[token] for token in tokenizer(x)]      # a token is a word; vocab[token] is its integer id
label_pipeline = lambda x: int(x) - 1       # map labels 1, 2, 3, 4 to classes 0, 1, 2, 3

# generate data batches and iterators
from torch.utils.data import DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def collate_batch(batch):
    label_list, text_list, offsets = [], [], [0]
    for (_label, _text) in batch:
        label_list.append(label_pipeline(_label))
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)      # torch.Size([41]), torch.Size([58])...
        text_list.append(processed_text)
        offsets.append(processed_text.size(0))

    label_list = torch.tensor(label_list, dtype=torch.int64)        # torch.Size([64])
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)      # torch.Size([64])
    text_list = torch.cat(text_list)        # concatenate the list of tensors into one tensor
    return label_list.to(device), text_list.to(device), offsets.to(device)

# dataloader = DataLoader(train_iter, batch_size=8, shuffle=False, collate_fn=collate_batch)

# import ipdb
from torch import nn
class TextClassificationModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_class):
        super(TextClassificationModel, self).__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)      # fill the tensor with values sampled from a uniform distribution
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)        # torch.Size([64, 64])
        output = self.fc(embedded)      # torch.Size([64, 4])
        return output


# num_class = len(set([label for (label, text) in train_iter]))       # the iterator has to be re-created first, i.e. train_iter = AG_NEWS(root=path, split='train')
num_class = 4
vocab_size = len(vocab)
emsize = 64
model = TextClassificationModel(vocab_size, emsize, num_class).to(device)


import time
def train(dataloader):
    model.train()       # training mode
    total_acc, total_count = 0, 0
    log_interval = 500
    start_time = time.time()

    for idx, (label, text, offsets) in enumerate(dataloader):
        optimizer.zero_grad()
        predited_label = model(text, offsets)
        loss = criterion(predited_label, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)     # enforce the maximum gradient norm max_norm
        optimizer.step()
        total_acc += (predited_label.argmax(1) == label).sum().item()
        total_count += label.size(0)
        if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches, accuracy {:8.3f}'.format(epoch, idx, len(dataloader), total_acc / total_count))
            total_acc, total_count = 0, 0
            start_time = time.time()


def evaluate(dataloader):
    model.eval()
    total_acc, total_count = 0, 0

    with torch.no_grad():
        for idx, (label, text, offsets) in enumerate(dataloader):
            predited_label = model(text, offsets)
            # loss = criterion(predited_label, label)
            total_acc += (predited_label.argmax(1) == label).sum().item()
            total_count += label.size(0)
    return total_acc / total_count


def predict(text, text_pipeline):
    with torch.no_grad():
        text = torch.tensor(text_pipeline(text)).to(device)     # keep the input on the same device as the model
        output = model(text, torch.tensor([0]).to(device))
        return output.argmax(1).item() + 1


from torch.utils.data.dataset import random_split

if __name__ == '__main__':
    # Hyperparameters
    # EPOCHS = 10  # epoch
    # LR = 5  # learning rate
    # BATCH_SIZE = 64  # batch size for training
    #
    # criterion = torch.nn.CrossEntropyLoss()
    # optimizer = torch.optim.SGD(model.parameters(), lr=LR)
    # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
    # total_accu = None
    # train_iter, test_iter = AG_NEWS(root=path)
    # train_dataset = list(train_iter)
    # test_dataset = list(test_iter)
    # num_train = int(len(train_dataset) * 0.95)
    # split_train_, split_valid_ = random_split(train_dataset, [num_train, len(train_dataset) - num_train])
    #
    # train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)      # shuffle=True randomly reorders the samples
    # valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
    # test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
    #
    # for epoch in range(1, EPOCHS + 1):
    #     epoch_start_time = time.time()
    #     train(train_dataloader)
    #     accu_val = evaluate(valid_dataloader)
    #     if total_accu is not None and total_accu > accu_val:
    #         scheduler.step()
    #     else:
    #         total_accu = accu_val
    #     print('-' * 59)
    #     print('| end of epoch {:3d} | time: {:5.2f}s | '
    #           'valid accuracy {:8.3f} '.format(epoch, time.time() - epoch_start_time, accu_val))
    #     print('-' * 59)
    #
    #
    # print('Checking the results of test dataset.')
    # accu_test = evaluate(test_dataloader)
    # print('test accuracy {:8.3f}'.format(accu_test))
    #
    # torch.save(model.state_dict(), '... your path\\model_TextClassification.pth')


    # Evaluation below
    model.load_state_dict(torch.load('... your path\\model_TextClassification.pth'))

    ag_news_label = {1: "World",
                     2: "Sports",
                     3: "Business",
                     4: "Sci/Tec"}

    # ex_text_str = "MEMPHIS, Tenn. – Four days ago, Jon Rahm was enduring the season’s worst weather conditions on Sunday at The Open on his way to a closing 75 at Royal Portrush, which considering the wind and the rain was a respectable showing. Thursday’s first round at the WGC-FedEx St. Jude Invitational was another story. With temperatures in the mid-80s and hardly any wind, the Spaniard was 13 strokes better in a flawless round. Thanks to his best putting performance on the PGA Tour, Rahm finished with an 8-under 62 for a three-stroke lead, which was even more impressive considering he’d never played the front nine at TPC Southwind."
    ex_text_str = 'Beijing of Automation, Beijing Institute of Technology'
    # model = model.to("cpu")

    print("This is a %s news" % ag_news_label[predict(ex_text_str, text_pipeline)])

Summary

  1. Getting the dataset involved some pitfalls: the Chinese tutorial is wrong and not kept up to date, so you still have to read the English one; and the download from GitHub kept stalling, and only finished with IDM…
  2. Preparing the data pipelines is essentially a one-hot-style (integer index) encoding of English words
  3. The DataLoader for data batches and iterators should be quite important: it turns the data into a stream for processing instead of reading everything in at once and blowing up memory; collate_fn, which turns a batch into tensors, is a bit hard to grasp on first contact
  4. The model is fairly simple: each word is embedded and then the embeddings are averaged to represent a sentence
  5. There is a learning-rate update step during training worth borrowing; the validation set it builds doesn't seem all that useful though…

Future work:

  1. Reproduce the code of the other TorchText experiment
  2. Study the BERT and Transformer models and implement them in code
