迭代器
迭代器迭代器统一了所有不同数据类型的遍历工作,可以任意组装我们想要的数据格式和大小。
python中可迭代的类型 str,list,tuple,dict,set,open()
迭代器特点:
只能向前遍历不能向后
节省内存空间
自定义迭代器必须实现以下两个特殊方法:
iter():返回对象本身作为迭代器
next(): 根据需要从数据源中获取并组装数据
代码实现
实现场景,现在假设我们有一个非常大的.txt文件,每一行是一个文本加一个标签(多分类训练数据集)。我们想将训练文本数据转化为数字,并且按照batch_size的大小送入模型中进行训练。
实现流程:
- 将训练数据集读入内存中,按照格式封装
- 自定义迭代器类,封装需要返回的数据和数据的大小
- 使用迭代器获取数据,将数据送入模型中
train.txt训练数据
信用贷,蚂蚁金服,微粒贷,小象优品,京东金条 14
信用贷,蚂蚁金服,微粒贷,银行渠道部,小象优品 14
信用贷,蚂蚁金服,微粒贷,银行渠道部,银行 14
. . .
config.py配置文件:定义了训练数据集路径,词表,mini-batch等
import torch
class Config(object):
"""
配置参数
"""
def __init__(self):
# 数据集路径
self.train_path = 'train.txt'
# 词表路径
self.vocab_path = "vocab.pkl"
# 设备
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# mini-batch大小
self.batch_size = 32
# 每句话处理成的长度(短填长切)
self.pad_size = 25
build_dataset.py文件的build_dataset函数定义句子切词方法是字符还是词语,load_dataset函数从文本中加载数据同时按照固定格式处理文本。
import os
import pickle as pkl
from tqdm import tqdm
MAX_VOCAB_SIZE = 10000
UNK, PAD = '<UNK>', '<PAD>'
def build_dataset(config, ues_word):
if ues_word:
tokenizer = lambda x: x.split(' ') # 词级别划分以空格隔开
else:
tokenizer = lambda x: [y for y in x] # 字符级别划分
assert os.path.exists(config.vocab_path)
vocab = pkl.load(open(config.vocab_path, 'rb'))
# 按行读取并解析成([...], 标签, 句子长度)存入contents返回
def load_dataset(path, pad_size=10):
contents = []
with open(path, 'r', encoding='UTF-8') as f:
for line in tqdm(f):
lin = line.strip()
if not lin:
continue
# 训练数据集内容和标签按空格隔开
content, label = lin.split(' ')
words_line = []
# 切词
token = tokenizer(content)
seq_len = len(token)
if seq_len < pad_size:
token.extend([PAD] * (pad_size - seq_len))
else:
token = token[:pad_size]
seq_len = pad_size
for word in token:
words_line.append(vocab.get(word, vocab.get(UNK)))
contents.append((words_line, int(label), seq_len))
return contents
train = load_dataset(config.train_path, config.pad_size)
# dev = load_dataset(config.dev_path, config.pad_size)
# test = load_dataset(config.test_path, config.pad_size)
# return vocab, train, dev, test
return vocab, train
iterator.py文件自定义迭代器,build_iterator函数实例化迭代器对象并返回。
import torch
class DatasetIterator(object):
"""
自定义迭代器类, 构造迭代处理的逻辑
注意:
1. 当数据小于batch_size时候, 一次性全部取出
2. 当数据大于等于batch_size时, 先按整批次取出,最后一次取剩余所有数据
"""
def __init__(self, batches, batch_size, device):
self.batch_size = batch_size
self.batches = batches
self.n_batches = len(batches) // batch_size
self.residue = False
# 判断数据集被batch_size整除后是否有剩余
if len(batches) % self.batch_size != 0:
self.residue = True
self.index = 0
self.device = device
def _to_tensor(self, datas):
"""
将数据转化为tensor, 入参datas格式[([...], 标签, 句子长度), ...]
"""
# 取训练数据
x = torch.LongTensor([_[0] for _ in datas]).to(self.device)
# 取数据对应的标签
y = torch.LongTensor([_[1] for _ in datas]).to(self.device)
# 取每个数据句子的长度,超过pad_size的已经设为pad_size
seq_len = torch.LongTensor([_[2] for _ in datas]).to(self.device)
return (x, seq_len), y
def __next__(self):
"""
自定义迭代器每次取数据的大小
"""
# 取最后一批不满batch_size的数据
if self.residue and self.index == self.n_batches:
batches = self.batches[self.index * self.batch_size: len(self.batches)]
self.index += 1
batches = self._to_tensor(batches)
return batches
# 迭代结束
elif self.index >= self.n_batches:
self.index = 0
raise StopIteration
# 整批次取数据
else:
batches = self.batches[self.index * self.batch_size: (self.index + 1) * self.batch_size]
self.index += 1
batches = self._to_tensor(batches)
return batches
def __iter__(self):
return self
def build_iterator(dataset, config):
iter = DatasetIterator(dataset, config.batch_size, config.device)
return iter
main.py文件打印自定义迭代器中的数据。
from iterator import build_iterator
from build_dataset import build_dataset
from config import Config
config = Config()
vocab, train_data = build_dataset(config, False)
train_iter = build_iterator(train_data, config)
# 需要在Config类中增加验证集和测试集数据路径
# dev_iter = build_iterator(dev_data, config)
# test_iter = build_iterator(test_data, config)
for i, (trains, labels) in enumerate(train_iter):
print(f"第{i + 1}次取数据")
print(f"每次迭代取训练集数据的大小={len(trains[0])}")
print(f"每次迭代取训练集标签的大小={len(labels)}")
# print(f"labels={labels}")
train.txt文件共120条数据,batch_size=32,那么就是执行4次,前3次每次取32条数据,最后一次是剩余的24条数据,打印结果如下:
总结
- 迭代器可以比较灵活的构造数据返回的格式和大小。
- 迭代器必须实现__iter__和__next__两个特殊的方法,前者返迭代器本身后者是对数据的具体操作。