Main idea
- For a dataset that needs a train/test split, first segment it into many small files, then use random.sample() to draw files for the train and test sets.
- Use torch.utils.data.IterableDataset and override __iter__ to read the dataset as a stream.
- Implement shuffle with a reservoir-sampling buffer (a standalone sketch follows below).

The Criteo dataset with 45 million rows is used as the example.
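The reservoir-style shuffle can be illustrated in isolation before it is wired into the dataset class. This is only a minimal sketch of the idea; the input stream and buffer size here are made up for illustration:

import random

def shuffled(stream, buffer_size=8):
    # Fill a fixed-size buffer, then emit a random occupant each time a new item arrives.
    buffer = []
    for item in stream:
        if len(buffer) < buffer_size:
            buffer.append(item)
        else:
            i = random.randrange(buffer_size)
            yield buffer[i]
            buffer[i] = item
    random.shuffle(buffer)   # flush whatever is left when the stream ends
    yield from buffer

print(list(shuffled(range(20), buffer_size=8)))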
Splitting the dataset
import os

data_root_path = "./data/criteo_4500/"
num_seg_line = 204800  # lines per segment file

def mkdir(path):
    if not os.path.exists(path):
        os.mkdir(path)

def segment(root_path, dataset_type):
    """Split {dataset_type}.txt into segment files of at most num_seg_line lines each."""
    i = 0
    with open(os.path.join(root_path, f'{dataset_type}.txt'), 'r') as file:
        line = file.readline()
        while line:
            with open(os.path.join(root_path, f"segmented_{dataset_type}/{dataset_type}_{i}.txt"), 'w') as subdataset:
                for _ in range(num_seg_line):
                    subdataset.write(line)
                    line = file.readline()
                    if not line:
                        break
            i += 1

if __name__ == "__main__":
    mkdir(os.path.join(data_root_path, "segmented_train/"))
    mkdir(os.path.join(data_root_path, "segmented_test/"))
    segment(data_root_path, 'train')
    segment(data_root_path, 'test')
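As an optional sanity check, each segment should contain at most num_seg_line lines. A small sketch, assuming the directory layout created above:

import os

seg_dir = "./data/criteo_4500/segmented_train/"
names = sorted(os.listdir(seg_dir))
for name in names:
    with open(os.path.join(seg_dir, name), 'r') as f:
        assert sum(1 for _ in f) <= 204800, name  # num_seg_line from the script above
print("segments:", len(names))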
Creating word2id.pkl
Generate a dictionary mapping each categorical feature value (word) to an integer id.
import pickle

word2id = dict()
count = 0
with open('../data/criteo_4500/train.txt', 'r') as f:
    l = f.readline().strip()
    while l:
        features = l.split('\t')
        # columns 14..39 are the categorical features C1..C26
        for feature in features[14:]:
            if feature == "":
                feature = "***"          # placeholder for empty fields
            if feature not in word2id:   # `not word2id.get(feature)` would wrongly re-map the entry with id 0
                word2id[feature] = len(word2id)
        if count % 1000000 == 0:
            print(count)                 # progress indicator
        count += 1
        l = f.readline().strip()

print(len(word2id))
f_save = open('../data/word2id.pkl', 'wb')
pickle.dump(word2id, f_save)
f_save.close()

# # Read back
# f_read = open('./data/word2id.pkl', 'rb')
# dict2 = pickle.load(f_read)
# print(dict2)
# f_read.close()
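Once the pickle exists, using the mapping is a plain dictionary lookup, which is exactly what process_line in the next section does. A minimal sketch (the feature value here is made up):

import pickle

with open('../data/word2id.pkl', 'rb') as f:
    word2id = pickle.load(f)

feature = ""                                         # empty categorical fields map to "***"
print(word2id["***" if feature == "" else feature])  # integer id of this categorical value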
Streaming the dataset with IterableDataset
import torch
from torch.utils.data import IterableDataset
import numpy as np
import os
import logging
import pickle
import random
from itertools import chain, islice
import re
import torch.distributed as dist
import math
# NAMES = ['label', 'I1', 'I2', 'I3', 'I4', 'I5', 'I6', 'I7', 'I8', 'I9', 'I10', 'I11',
# 'I12', 'I13', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7', 'C8', 'C9', 'C10', 'C11',
# 'C12', 'C13', 'C14', 'C15', 'C16', 'C17', 'C18', 'C19', 'C20', 'C21', 'C22',
# 'C23', 'C24', 'C25', 'C26']
def split_train_test(root_path, train_percent=0.8):
    """Shuffle the segment files and split them into train/test file lists."""
    file_list = os.listdir(root_path)
    shuffle_list = random.sample(file_list, len(file_list))
    train_file_list = [os.path.join(root_path, filename)
                       for filename in shuffle_list[: int(len(shuffle_list) * train_percent)]]
    test_file_list = [os.path.join(root_path, filename)
                      for filename in shuffle_list[int(len(shuffle_list) * train_percent):]]
    return train_file_list, test_file_list
def create_iterable_criteo_dataset(data_root, batch_size=1024, shuffle=True, buffer_size=100000,
                                   train_percent=0.9, distributed=False):
    """Build streaming train/val datasets from the segmented Criteo files.

    Expects `segmented_train/`, `segmented_val/` and `word2id.pkl` under `data_root`.

    :param data_root: A string, root directory of the dataset.
    :param batch_size: An int, used by the shuffle-buffer index generator.
    :param shuffle: A boolean, whether to shuffle with the reservoir buffer.
    :param buffer_size: An int, size of the shuffle buffer.
    :param train_percent: A float, only used if the files are split with split_train_test().
    :param distributed: A boolean, whether the files are sharded across DDP ranks.
    :return: train_dataset, val_dataset, word2id, num_int_feat
    """
    num_int_feat = 13  # I1..I13 integer features
    with open(os.path.join(data_root, 'word2id.pkl'), 'rb') as file:
        word2id = pickle.load(file)
    # Alternatively, split a single directory of segments on the fly:
    # train_file_list, test_file_list = split_train_test(os.path.join(data_root, 'segmented_train'), train_percent)
    train_file_list = [os.path.join(data_root, "segmented_train", filename)
                       for filename in os.listdir(os.path.join(data_root, "segmented_train/"))]
    test_file_list = [os.path.join(data_root, "segmented_val", filename)
                      for filename in os.listdir(os.path.join(data_root, "segmented_val/"))]
    train_dataset = IterableCriteoDataset(train_file_list, batch_size=batch_size, word2id=word2id,
                                          shuffle=shuffle, buffer_size=buffer_size, distributed=distributed)
    val_dataset = IterableCriteoDataset(test_file_list, batch_size=batch_size, word2id=word2id,
                                        shuffle=shuffle, buffer_size=buffer_size, distributed=distributed)
    logging.info("id_hash_set:{}".format(len(word2id)))
    return train_dataset, val_dataset, word2id, num_int_feat
class IterableCriteoDataset(IterableDataset):
    """
    Custom dataset class for the Criteo dataset, so PyTorch's efficient
    DataLoader can be used on a streamed dataset.
    """
    def __init__(self, file_list: list, batch_size, word2id, shuffle=True, buffer_size=100000,
                 distributed=False, epoch=0, seed=5):
        print("IterableDataset")
        # Assumes all segments have (roughly) the same number of lines; only used by __len__.
        self.file_line_num = self.get_file_line_num(file_list[-1])
        self.file_list = file_list
        self.batch_size = batch_size
        self.buffer_size = buffer_size
        self.shuffle = shuffle
        self.word2id = word2id
        self.distributed = distributed
        if self.distributed:
            self.world_size = dist.get_world_size()
            self.file_num_per_rank = math.ceil(len(self.file_list) / self.world_size)
        self.epoch = epoch
        self.seed = seed
        self.generator = torch.Generator()

    def set_epoch(self, epoch):
        # Call before each epoch so the file order is reshuffled deterministically.
        self.epoch = epoch

    def file_mapper(self, file_name):
        # Lazily yield parsed samples from one segment file.
        with open(file_name, 'r') as f:
            line = f.readline()
            while line:
                yield self.process_line(line)
                line = f.readline()

    def process_line(self, line):
        items = re.split('\t', line.replace('\n', ""))
        assert len(items) == 40  # label + 13 integer features + 26 categorical features
        y = int(items[0])
        x = [0. if item == '' else int(item) for item in items[1:14]]
        for f in items[14:]:
            x.append(self.word2id["***" if f == "" else f])
        return torch.tensor(x, dtype=torch.float32), torch.tensor(y, dtype=torch.float32)
    def __iter__(self):
        # Shard the sample stream across DataLoader workers.
        worker_info = torch.utils.data.get_worker_info()
        if worker_info:
            worker_total_num = worker_info.num_workers
            worker_id = worker_info.id
        else:
            worker_id = 0
            worker_total_num = 1
        # Same seed on every rank/worker so the file permutation is identical everywhere.
        self.generator.manual_seed(self.seed + self.epoch)
        if self.distributed:
            local_rank = dist.get_rank()
            idx = torch.randperm(len(self.file_list), generator=self.generator).numpy()
            self.file_list = (np.array(self.file_list)[idx]).tolist()
            file_itr = self.file_list[local_rank * self.file_num_per_rank:(local_rank + 1) * self.file_num_per_rank]
            if len(file_itr) < self.file_num_per_rank:
                # Pad the last rank with resampled files so every rank sees the same number of files.
                file_itr.extend(random.sample(file_itr, self.file_num_per_rank - len(file_itr)))
        else:
            idx = torch.randperm(len(self.file_list), generator=self.generator).numpy()
            self.file_list = (np.array(self.file_list)[idx]).tolist()
            file_itr = self.file_list
        # Chain all files into one sample stream, then stride it over the workers.
        file_mapped_itr = chain.from_iterable(map(self.file_mapper, file_itr))
        file_mapped_itr = islice(file_mapped_itr, worker_id, None, worker_total_num)
        if self.shuffle:
            return self._shuffle(file_mapped_itr)
        else:
            return file_mapped_itr
    def get_file_line_num(self, path):
        # Count the lines of one segment file.
        with open(path, 'r') as f:
            return sum(1 for _ in f)

    def __len__(self):
        # Approximate length: number of files times lines per file.
        if self.distributed:
            return self.file_num_per_rank * self.file_line_num
        return len(self.file_list) * self.file_line_num

    def generate_random_num(self):
        # Endless stream of buffer indices, drawn batch_size at a time without replacement.
        while True:
            random_nums = random.sample(range(self.buffer_size), self.batch_size)
            yield from random_nums

    def _shuffle(self, mapped_itr):
        # Reservoir-style shuffle buffer: keep buffer_size samples, emit a random one
        # whenever a new sample arrives, then flush the remainder at the end.
        buffer = []
        rand_idx = self.generate_random_num()
        for dt in mapped_itr:
            if len(buffer) < self.buffer_size:
                buffer.append(dt)
            else:
                i = next(rand_idx)
                yield buffer[i]
                buffer[i] = dt
        random.shuffle(buffer)
        yield from buffer
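A minimal usage sketch, assuming data_root contains segmented_train/, segmented_val/ and word2id.pkl as produced above. The dataset already yields single (x, y) samples, so the DataLoader only does the batching, and the worker sharding in __iter__ keeps multiple workers from reading duplicate samples:

from torch.utils.data import DataLoader

train_ds, val_ds, word2id, num_int_feat = create_iterable_criteo_dataset(
    "./data/criteo_4500/", batch_size=1024, shuffle=True, buffer_size=100000)

train_loader = DataLoader(train_ds, batch_size=1024, num_workers=4)

for epoch in range(1):
    train_ds.set_epoch(epoch)        # reshuffle the file order for this epoch
    for x, y in train_loader:
        print(x.shape, y.shape)      # e.g. torch.Size([1024, 39]) torch.Size([1024])
        break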