NLP Beginner's Guide - Week N1: Getting Started with Text Classification in PyTorch

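This post records the steps for text classification with PyTorch: environment setup, loading the AG News dataset, building a vocabulary, defining the model, and training and evaluation. It is aimed at NLP beginners.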

I. Background and Development Environment

📌Week N1: Getting Started with Text Classification in PyTorch📌

  • Python 3.8.12
  • pytorch==1.8.1+cu111
  • torchtext==0.9.1
  • portalocker==2.7.0

II. Environment Setup

This is a simple hands-on text classification example implemented in PyTorch. It uses the AG News dataset.

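AG News (AG’s News Topic Classification Dataset) is widely used for text classification, particularly in the news domain. It was compiled from AG’s Corpus of News Articles and contains four main classes: World, Sports, Business, and Sci/Tech.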
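In the raw AG News files, each sample carries a numeric label from 1 to 4. The labels follow the class order listed above. The sketch below shows that mapping. The names `AG_NEWS_LABELS`, `label_name` and `to_zero_based` are hypothetical helpers and are not part of torchtext. The shift to 0-based ids is included because `nn.CrossEntropyLoss` expects class indices in the range 0 to C-1.

```python
# Hypothetical helpers: AG News labels are 1-based, in the order World, Sports, Business, Sci/Tech
AG_NEWS_LABELS = {1: "World", 2: "Sports", 3: "Business", 4: "Sci/Tech"}


def label_name(label: int) -> str:
    """Return the class name for a raw (1-based) AG News label."""
    return AG_NEWS_LABELS[label]


def to_zero_based(label: int) -> int:
    """Shift a 1-based label into 0..3, the index range a 4-class loss expects."""
    return label - 1


print(label_name(3))     # prints: Business
print(to_zero_based(3))  # prints: 2
```

Whether the shift is applied in the label pipeline or later in the batching code is a design choice. The later code in this post keeps `int(x)` unchanged.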

First, make sure torchtext and portalocker are installed.

torchtext version compatibility (for choosing which version to install):

PyTorch version | torchtext version | Supported Python versions
nightly build   | main              | >=3.8, <=3.11
1.14.0          | 0.15.0            | >=3.8, <=3.11
1.13.0          | 0.14.0            | >=3.7, <=3.10
1.12.0          | 0.13.0            | >=3.7, <=3.10
1.11.0          | 0.12.0            | >=3.6, <=3.9
1.10.0          | 0.11.0            | >=3.6, <=3.9
1.9.1           | 0.10.1            | >=3.6, <=3.9
1.9             | 0.10              | >=3.6, <=3.9
1.8.1           | 0.9.1             | >=3.6, <=3.9
1.8             | 0.9               | >=3.6, <=3.9
1.7.1           | 0.8.1             | >=3.6, <=3.9
1.7             | 0.8               | >=3.6, <=3.8
1.6             | 0.7               | >=3.6, <=3.8
1.5             | 0.6               | >=3.5, <=3.8
1.4             | 0.5               | 2.7, >=3.5, <=3.8
0.4 and below   | 0.2.3             | 2.7, >=3.5, <=3.8
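The table can also be checked programmatically. The sketch below copies the numbered rows into a dictionary and looks up the torchtext release paired with a PyTorch version string. `expected_torchtext` is a hypothetical helper, and it only knows the releases listed in the table. Unlisted patch releases such as 1.10.1 return None rather than a guess.

```python
from typing import Optional

# Rows copied from the table above: PyTorch release -> paired torchtext release
TORCH_TO_TORCHTEXT = {
    "1.14.0": "0.15.0", "1.13.0": "0.14.0", "1.12.0": "0.13.0",
    "1.11.0": "0.12.0", "1.10.0": "0.11.0", "1.9.1": "0.10.1",
    "1.9": "0.10", "1.8.1": "0.9.1", "1.8": "0.9",
    "1.7.1": "0.8.1", "1.7": "0.8", "1.6": "0.7",
    "1.5": "0.6", "1.4": "0.5",
}


def expected_torchtext(torch_version: str) -> Optional[str]:
    """Return the torchtext release paired with a PyTorch version, or None if unlisted."""
    base = torch_version.split("+")[0]  # drop local build tags such as '+cu111'
    if base in TORCH_TO_TORCHTEXT:
        return TORCH_TO_TORCHTEXT[base]
    if base.endswith(".0"):             # the table writes e.g. '1.8' for 1.8.0
        return TORCH_TO_TORCHTEXT.get(base[:-2])
    return None


print(expected_torchtext("1.8.1+cu111"))  # prints: 0.9.1
```

With PyTorch installed, you can compare `expected_torchtext(torch.__version__)` with `torchtext.__version__`.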

III. Text Classification

1. Load the Data

import os
import sys
import PIL
from PIL import Image
import time
import copy
import random
import pathlib
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchtext.datasets import AG_NEWS
import torchvision
from torchinfo import summary
import torchsummary
import matplotlib.pyplot as plt
import numpy as np
import warnings


''' Download or read the training and test sets of the AG News dataset '''
def getDataset(root, dataset):
    # Create the root directory if it does not exist yet
    os.makedirs(root, exist_ok=True)
    if not os.path.isdir(dataset):
        print('Downloading dataset...\n')
        # Download AG News (running this directly reported a network error, so the download failed)
        train_ds, test_ds = AG_NEWS(root=root, split=("train", "test"))
    else:
        print('Dataset already downloaded, reading...\n')
        # Read the local copy: after downloading train.csv and test.csv manually, the data can be loaded from disk
        train_ds, test_ds = AG_NEWS(root=dataset, split=("train", "test"))
    # Optional debug prints; note that they consume items from the dataset iterators
    #print("Train:", next(train_ds), len(list(train_ds))+1)
    #print("Test :", next(test_ds), len(list(test_ds))+1)
    return train_ds, test_ds


''' Set up the GPU '''
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using {} device\n".format(device))
''' Load the data '''
root = './data/'
data_dir = os.path.join(root, 'AG_NEWS.data')
train_ds, test_ds = getDataset(root, data_dir)
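Output (the Train/Test lines come from the two print statements that are commented out in getDataset):

Using cuda device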

Dataset already downloaded, reading...

Train: (3, "Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again.") 120000
Test : (3, "Fears for T N pension after talks Unions representing workers at Turner   Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul.") 7600
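When the automatic download fails, the comment in getDataset suggests downloading train.csv and test.csv by hand. The sketch below reads an AG News-style CSV with the standard library, to show where the (label, text) pairs printed above come from. It assumes each row holds three quoted fields (label, title, description). It joins the title and description with a space, which matches the printed samples. `read_ag_news_csv` and the two-row sample are hypothetical, and this is not torchtext's own loader.

```python
import csv
import io
from typing import Iterator, TextIO, Tuple


def read_ag_news_csv(f: TextIO) -> Iterator[Tuple[int, str]]:
    """Yield (label, text) pairs, assuming rows of the form label, title, description."""
    for row in csv.reader(f):
        label, *fields = row
        yield int(label), " ".join(fields)  # title + " " + description


# Hypothetical two-row sample in the assumed train.csv layout
sample = io.StringIO(
    '"3","Wall St. Bears Claw Back Into the Black (Reuters)","Reuters - Short-sellers are seeing green again."\n'
    '"2","Title two","Description two."\n'
)
for label, text in read_ag_news_csv(sample):
    print(label, text)
```

For a real file, pass in a handle such as `open(os.path.join(data_dir, 'train.csv'), newline='', encoding='utf-8')`. The exact location depends on where the CSVs were saved.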

2. Build the Vocabulary

from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

''' Build the vocabulary '''
def buildDict(train_ds):
    tokenizer = get_tokenizer('basic_english')  # returns a tokenizer function
    def yield_tokens(data_iter):
        for _, text in data_iter:
            yield tokenizer(text)
    vocab = build_vocab_from_iterator(yield_tokens(train_ds))
    text_pipeline  = lambda x: vocab.lookup_indices(tokenizer(x))  # text -> list of token ids
    label_pipeline = lambda x: int(x)                              # label string -> int
    #print(vocab.UNK, vocab._default_unk_index())  # print the default index, used for words not in the vocabulary
    #print(vocab.lookup_indices(['here', 'is', 'an', 'example']))
    #print(text_pipeline('here is the an example'))
    #print(label_pipeline('10'))
    return vocab, text_pipeline, label_pipeline


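# Build the vocabulary; buildDict returns three values, so unpack all of them
# Note: in torchtext 0.9.x the AG_NEWS splits are one-shot iterators, so this call consumes train_ds;
# reload it with getDataset before passing it to loadData
vocab, text_pipeline, label_pipeline = buildDict(train_ds)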
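Output (the lines after the progress bar come from the print statements that are commented out in buildDict):

120001lines [00:04, 27817.88lines/s]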
<unk> 0
[471, 22, 31, 5177]
[471, 22, 3, 31, 5177]
10
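The vocabulary step can also be illustrated without torchtext. The sketch below is a simplified, hypothetical stand-in and not torchtext's implementation. Its regex tokenizer is cruder than `basic_english`. It assumes the ordering used by the torchtext 0.9 `Vocab`: the specials `<unk>` and `<pad>` come first, then tokens by descending frequency, with ties broken alphabetically. Unknown words fall back to index 0, which is the `<unk> 0` printed above. In this toy corpus `the` never occurs, so it maps to 0.

```python
import re
from collections import Counter
from typing import Dict, Iterable, List


def simple_tokenizer(text: str) -> List[str]:
    """Simplified tokenizer: lowercase, keep runs of letters and digits."""
    return re.findall(r"[a-z0-9]+", text.lower())


def build_vocab(texts: Iterable[str], specials=("<unk>", "<pad>")) -> Dict[str, int]:
    """Map tokens to ids: specials first, then by descending frequency (ties alphabetical)."""
    counter = Counter(tok for t in texts for tok in simple_tokenizer(t))
    ordered = sorted(counter.items(), key=lambda kv: (-kv[1], kv[0]))
    stoi = {tok: i for i, tok in enumerate(specials)}
    for tok, _ in ordered:
        stoi.setdefault(tok, len(stoi))
    return stoi


def lookup_indices(stoi: Dict[str, int], tokens: List[str], unk: str = "<unk>") -> List[int]:
    """Look up ids; tokens missing from the vocabulary get the <unk> id."""
    return [stoi.get(tok, stoi[unk]) for tok in tokens]


corpus = ["Here is an example", "here is another example", "an example"]
vocab = build_vocab(corpus)
print(lookup_indices(vocab, simple_tokenizer("here is the an example")))  # prints: [4, 5, 0, 3, 2]
```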

3. Generate Data Batches and Iterators

''' Load the data and set batch_size '''
def loadData(train_ds, test_ds, batch_size=8, device='cpu'):
    # Build the vocabulary
    vocab, text_pipeline, label_pipeline = buildDict(train_ds)
    # Generate data batches and iterators
    def collate_batch(batch):
        label_list, text_list, offsets = [
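The rest of collate_batch is cut off in this copy of the post. The names `label_list, text_list, offsets` suggest the pattern from the official PyTorch text classification tutorial. In that pattern, a batch's token ids are concatenated into one 1-D tensor, and `offsets` records where each sample starts, so `nn.EmbeddingBag` can pool each sample separately. The pure-Python sketch below shows how such offsets are computed, assuming that pattern. `flatten_with_offsets` is hypothetical, and this is not the author's code.

```python
from typing import List, Tuple


def flatten_with_offsets(token_id_lists: List[List[int]]) -> Tuple[List[int], List[int]]:
    """Concatenate per-sample token ids and record each sample's start position.

    offsets[i] is the index of sample i's first token in the flat list,
    i.e. the exclusive cumulative sum of the sample lengths.
    """
    flat: List[int] = []
    offsets: List[int] = []
    for ids in token_id_lists:
        offsets.append(len(flat))
        flat.extend(ids)
    return flat, offsets


batch = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
flat, offsets = flatten_with_offsets(batch)
print(offsets)  # prints: [0, 3, 5]
```

The official tutorial gets the same values another way. It starts `offsets` as `[0]`, appends each sample's length, and then takes `torch.tensor(offsets[:-1]).cumsum(dim=0)`.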