- 🍨 This post is a learning log from the 🔗365天深度学习训练营 (365-Day Deep Learning Training Camp)
- 🍖 Original author: K同学啊 (tutoring and custom projects available)
I. Project Background and Development Environment
📌 Week N1: Getting Started with Text Classification in PyTorch 📌
- Python 3.8.12
- pytorch==1.8.1+cu111
- torchtext==0.9.1
- portalocker==2.7.0
II. Environment Setup
This is a simple, hands-on text classification example implemented in PyTorch. In this example we will classify text from the AG News dataset.
AG News (AG's News Topic Classification Dataset) is a dataset widely used for text classification tasks, especially in the news domain. It was compiled from the AG's Corpus of News Articles and contains four main categories: World, Sports, Business, and Sci/Tech.
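For reference, torchtext's AG_NEWS encodes these categories as the integer labels 1-4. A small lookup table like the sketch below (the dict name ag_news_label is my own; the id-to-topic mapping follows the official dataset description) makes predicted labels readable later on:
# Map AG_NEWS integer labels (1-4) to human-readable topic names
ag_news_label = {1: "World", 2: "Sports", 3: "Business", 4: "Sci/Tech"}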
First, make sure the torchtext and portalocker libraries are installed. Use the compatibility table below to pick versions that match your PyTorch install; an example install command follows the table.
| PyTorch version | torchtext version | Supported Python version |
|---|---|---|
| nightly build | main | >=3.8, <=3.11 |
| 1.14.0 | 0.15.0 | >=3.8, <=3.11 |
| 1.13.0 | 0.14.0 | >=3.7, <=3.10 |
| 1.12.0 | 0.13.0 | >=3.7, <=3.10 |
| 1.11.0 | 0.12.0 | >=3.6, <=3.9 |
| 1.10.0 | 0.11.0 | >=3.6, <=3.9 |
| 1.9.1 | 0.10.1 | >=3.6, <=3.9 |
| 1.9 | 0.10 | >=3.6, <=3.9 |
| 1.8.1 | 0.9.1 | >=3.6, <=3.9 |
| 1.8 | 0.9 | >=3.6, <=3.9 |
| 1.7.1 | 0.8.1 | >=3.6, <=3.9 |
| 1.7 | 0.8 | >=3.6, <=3.8 |
| 1.6 | 0.7 | >=3.6, <=3.8 |
| 1.5 | 0.6 | >=3.5, <=3.8 |
| 1.4 | 0.5 | 2.7, >=3.5, <=3.8 |
| 0.4 and below | 0.2.3 | 2.7, >=3.5, <=3.8 |
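Given the environment listed above (Python 3.8, PyTorch 1.8.1), the matching row in the table is torchtext 0.9.1. A minimal install sketch, assuming pip and an already-installed PyTorch (adjust the versions to your own row in the table):
pip install torchtext==0.9.1 portalocker==2.7.0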
III. Text Classification
1. Load the Data
import os
import sys
import PIL
from PIL import Image
import time
import copy
import random
import pathlib
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader                 # used to batch the dataset in loadData below
from torchtext.datasets import AG_NEWS
from torchtext.data.utils import get_tokenizer          # tokenizer factory used in buildDict below
from torchtext.vocab import build_vocab_from_iterator   # vocabulary builder used in buildDict below
import torchvision
from torchinfo import summary
import torchsummary
import matplotlib.pyplot as plt
import numpy as np
import warnings
''' Download, or read from disk, the train and test splits of the AG News dataset '''
def getDataset(root, dataset):
    if not os.path.exists(root) or not os.path.isdir(root):
        os.makedirs(root)
    if not os.path.exists(dataset) or not os.path.isdir(dataset):
        print('Downloading dataset...\n')
        # Download AG News; running this directly may fail with a network error
        train_ds, test_ds = AG_NEWS(root=root, split=("train", "test"))
    else:
        print('Dataset already downloaded, reading...\n')
        # Read a local copy of AG News; after manually downloading train.csv and test.csv, the data can be loaded from disk
        train_ds, test_ds = AG_NEWS(root=dataset, split=("train", "test"))
    #print("Train:", next(train_ds), len(list(train_ds))+1)
    #print("Test :", next(test_ds), len(list(test_ds))+1)
    return train_ds, test_ds
''' Set up the device (GPU if available) '''
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using {} device\n".format(device))
''' Load the data '''
root = './data/'
data_dir = os.path.join(root, 'AG_NEWS.data')
train_ds, test_ds = getDataset(root, data_dir)
Using cuda device
Dataset already downloaded, reading...
Train: (3, "Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again.") 120000
Test : (3, "Fears for T N pension after talks Unions representing workers at Turner Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul.") 7600
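Each sample is a (label, text) tuple; label 3 here corresponds to the Business category. The trailing numbers match the official split sizes: AG News ships with 120,000 training samples and 7,600 test samples.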
2. Build the Vocabulary
''' Build the vocabulary '''
def buildDict(train_ds):
    tokenizer = get_tokenizer('basic_english')  # returns a tokenizer function
    def yield_tokens(data_iter):
        for _, text in data_iter:
            yield tokenizer(text)
    vocab = build_vocab_from_iterator(yield_tokens(train_ds))
    text_pipeline = lambda x: vocab.lookup_indices(tokenizer(x))
    label_pipeline = lambda x: int(x)
    #print(vocab.UNK, vocab._default_unk_index())  # print the default index, returned for words not in the vocabulary
    #print(vocab.lookup_indices(['here', 'is', 'an', 'example']))
    #print(text_pipeline('here is the an example'))
    #print(label_pipeline('10'))
    return vocab, text_pipeline, label_pipeline
# Build the vocabulary; buildDict returns three values, so all three must be unpacked
vocab, text_pipeline, label_pipeline = buildDict(train_ds)
120001lines [00:04, 27817.88lines/s]
<unk> 0
[471, 22, 31, 5177]
[471, 22, 3, 31, 5177]
10
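The output above corresponds to the commented-out checks: the vocabulary places <unk> at index 0 and returns it for out-of-vocabulary words, lookup_indices maps a list of tokens to vocabulary indices, text_pipeline tokenizes a raw string first (the extra 3 in the second list is the index of 'the'), and label_pipeline simply casts the label string to an integer. A quick sanity check along these lines (a sketch mirroring the commented prints) should reproduce it:
print(vocab.lookup_indices(['here', 'is', 'an', 'example']))  # [471, 22, 31, 5177]
print(text_pipeline('here is the an example'))                # [471, 22, 3, 31, 5177]
print(label_pipeline('10'))                                   # 10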
3. Generate Data Batches and Iterators
''' Load the data and set the batch_size '''
def loadData(train_ds, test_ds, batch_size=8, device='cpu'):
    # Build the vocabulary
    vocab, text_pipeline, label_pipeline = buildDict(train_ds)
    # Generate data batches and iterators
    def collate_batch(batch):
        label_list, text_list, offsets = [