前言
前面学习了相关自然语言编码,这周进行相关实战
导入依赖库和设置设备
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, datasets
import os, PIL, pathlib, warnings
warnings.filterwarnings("ignore") # 忽略警告
# win10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
这段代码导入了必要的库并设置了设备(GPU或CPU)。
数据预处理和词汇表构建
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import AG_NEWS
train_iter = AG_NEWS(split='train')
tokenizer = get_tokenizer('basic_english') # 返回分词器函数
def yield_tokens(data_iter):
for _, text in data_iter:
yield tokenizer(text)
vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"]) # 设置默认索引,如果找不到单词,则会选择默认索引
这里使用torchtext
库加载AG_NEWS数据集,定义了一个分词器并构建了词汇表。
数据处理管道
text_pipeline = lambda x: vocab(tokenizer(x))
label_pipeline = lambda x: int(x) - 1
text_pipeline('here is the an example')
定义了两个数据处理管道:text_pipeline
用于将文本转化为词汇表中的索引序列,label_pipeline
用于将标签转化为整数索引。
定义数据加载器
from torch.utils.data import DataLoader
def collate_batch(batch):
label_list, text_list, offsets = [], [], [0]
for (_label, _text) in batch:
label_list.append(label_pipeline(_label))
processed_text = torch.tensor(text_pipeline<