Shuffling and iterating over large datasets with PyTorch's IterableDataset

Main idea

  • For a dataset that still needs a train/test split, first cut it into many small segment files, then draw the train and test file lists with random.sample().
  • Use torch.utils.data.IterableDataset and override __iter__ so the data is read as a stream instead of being loaded into memory.
  • Implement shuffling with reservoir-style sampling over a fixed-size buffer (a short sketch follows this list).
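
The buffered "reservoir" shuffle can be sketched in isolation as below; this is only an illustrative helper (the name buffered_shuffle and the default buffer size are not from the original code), and the real implementation lives in IterableCriteoDataset._shuffle further down:

import random

def buffered_shuffle(iterator, buffer_size=10000):
    # keep a fixed-size buffer; once it is full, emit a random buffered item
    # for every new incoming item, then flush the remainder fully shuffled
    buffer = []
    for item in iterator:
        if len(buffer) < buffer_size:
            buffer.append(item)
        else:
            i = random.randrange(buffer_size)
            yield buffer[i]
            buffer[i] = item
    random.shuffle(buffer)
    yield from buffer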

We use the Criteo dataset with 45 million (4500w) rows as the running example.

Splitting the dataset

import os

data_root_path = "./data/criteo_4500/"

# number of lines per segment file
num_seg_line = 204800

def mkdir(path):
    if not os.path.exists(path):
        os.mkdir(path)

def segment(root_path, dataset_type):
    # split {dataset_type}.txt into files of num_seg_line lines each under segmented_{dataset_type}/
    i = 0
    with open(os.path.join(root_path, f'{dataset_type}.txt'), 'r') as file:
        line = file.readline()
        while line:
            with open(os.path.join(root_path, f"segmented_{dataset_type}/{dataset_type}_{i}.txt"), 'w') as subdataset:
                for _ in range(num_seg_line):
                    subdataset.write(line)
                    line = file.readline()
                    if not line:
                        break
            i += 1

if __name__ == "__main__":
    mkdir(os.path.join(data_root_path, "segmented_train/"))
    mkdir(os.path.join(data_root_path, "segmented_test/"))
    segment(data_root_path, 'train')
    segment(data_root_path, 'test')
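
As a quick sanity check (a hedged sketch reusing the paths from the script above), every segment under segmented_train/ except possibly the last one should contain exactly num_seg_line lines:

import os

data_root_path = "./data/criteo_4500/"
seg_dir = os.path.join(data_root_path, "segmented_train")
for name in sorted(os.listdir(seg_dir)):
    with open(os.path.join(seg_dir, name), 'r') as f:
        print(name, sum(1 for _ in f))  # expect 204800 for all but the last segment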

Creating word2id.pkl

Build a dictionary that maps every categorical feature value (word) to an integer id.

import pickle

word2id = dict()
count = 0
with open('../data/criteo_4500/train.txt', 'r') as f:
    l = f.readline().strip()
    while l:
        features = l.split('\t')

        # fields 14..39 are the 26 categorical features; empty values map to the placeholder "***"
        for feature in features[14:]:
            if feature == "": feature = "***"
            # use 'in' (not .get) so the value that received id 0 is never re-assigned
            if feature not in word2id: word2id[feature] = len(word2id)
        if count % 1000000 == 0: print(count)  # progress
        count += 1
        l = f.readline().strip()

print(len(word2id))

# save under the dataset root so that create_iterable_criteo_dataset below can find it
with open('../data/criteo_4500/word2id.pkl', 'wb') as f_save:
    pickle.dump(word2id, f_save)

# # read back
# with open('../data/criteo_4500/word2id.pkl', 'rb') as f_read:
#     word2id = pickle.load(f_read)
#     print(word2id)
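
Because ids are assigned densely in insertion order, len(word2id) can be used directly as the embedding vocabulary size. A small check of the saved mapping (a sketch assuming word2id.pkl was written as above):

import pickle

with open('../data/criteo_4500/word2id.pkl', 'rb') as f:
    word2id = pickle.load(f)

print(len(word2id))            # vocabulary size for the categorical embedding table
print(word2id.get("***"))      # id of the placeholder used for empty categorical fields
assert max(word2id.values()) == len(word2id) - 1   # ids are dense: 0 .. len-1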

Reading the dataset as a stream with IterableDataset

import torch
from torch.utils.data import IterableDataset
import numpy as np
import os
import logging
import pickle
import random
from itertools import chain, islice
import re
import torch.distributed as dist
import math

# NAMES = ['label', 'I1', 'I2', 'I3', 'I4', 'I5', 'I6', 'I7', 'I8', 'I9', 'I10', 'I11',
#          'I12', 'I13', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7', 'C8', 'C9', 'C10', 'C11',
#          'C12', 'C13', 'C14', 'C15', 'C16', 'C17', 'C18', 'C19', 'C20', 'C21', 'C22',
#          'C23', 'C24', 'C25', 'C26']
def split_train_test(root_path, train_percent = 0.8):
    file_list = os.listdir(root_path)
    shuffle_list = random.sample(file_list, len(file_list))
    # root_path = "./data/criteo_4500/segmented_train/"
    train_file_list = [os.path.join(root_path, filename) for filename in shuffle_list[: int(len(shuffle_list)*train_percent)]]
    test_file_list = [os.path.join(root_path, filename) for filename in shuffle_list[int(len(shuffle_list)*train_percent):]]
    return train_file_list, test_file_list

def create_iterable_criteo_dataset(data_root, batch_size = 1024, shuffle=True, buffer_size= 100000, train_percent=0.9, distributed = False):
    """Create streaming train/val datasets from the pre-segmented Criteo files.
    Args:
        :param data_root: A string. Directory containing segmented_train/, segmented_test/ and word2id.pkl.
        :param batch_size: An integer. Batch size (used by the buffered shuffle's index generator).
        :param shuffle: A boolean. Whether to shuffle samples with the fixed-size buffer.
        :param buffer_size: An integer. Size of the shuffle buffer.
        :param train_percent: A float. Only used when building file lists with split_train_test.
        :param distributed: A boolean. Whether segment files are sharded across distributed ranks.
    :return: train_dataset, val_dataset, word2id, num_int_feat
    """
    num_int_feat = 13
    with open(os.path.join(data_root,'word2id.pkl'), 'rb')  as file:
        word2id = pickle.load(file)

    # train_file_list, test_file_list = split_train_test(os.path.join(data_root, 'segmented_train'),train_percent)
    train_file_list = [os.path.join(data_root, "segmented_train", filename) for filename in os.listdir(os.path.join(data_root, "segmented_train/"))]
    test_file_list = [os.path.join(data_root, "segmented_test", filename) for filename in os.listdir(os.path.join(data_root, "segmented_test/"))]

    train_dataset = IterableCriteoDataset(train_file_list, batch_size = batch_size, word2id = word2id, shuffle=shuffle, buffer_size= buffer_size, distributed = distributed )
    val_dataset = IterableCriteoDataset(test_file_list, batch_size = batch_size, word2id = word2id, shuffle=shuffle, buffer_size= buffer_size, distributed = distributed )

    logging.info("id_hash_set:{}".format(len(word2id)))
    return train_dataset, val_dataset, word2id, num_int_feat

class IterableCriteoDataset(IterableDataset):
    """
    Custom dataset class for Criteo dataset in order to use efficient
    dataloader tool provided by PyTorch.
    """
    def __init__(self, file_list:list, batch_size, word2id, shuffle = True, buffer_size = 100000, distributed = False, epoch = 0, seed = 5):
        print("IterableDataset")
        # assume every segment holds (roughly) the same number of lines; only used by __len__
        self.file_line_num = self.get_file_line_num(file_list[-1])
        self.file_list = file_list
        self.batch_size = batch_size
        self.buffer_size = buffer_size
        self.shuffle = shuffle
        self.word2id = word2id
        self.distributed = distributed
        if self.distributed:
            self.world_size = dist.get_world_size()
            self.file_num_per_rank = math.ceil(len(self.file_list)/self.world_size)
        self.epoch = epoch
        self.seed = seed
        # generator that permutes the file order reproducibly; reseeded per epoch in __iter__
        self.generator = torch.Generator()
    def set_epoch(self, epoch):
        self.epoch = epoch
        
    def file_mapper(self, file_name):
        # lazily yield one parsed sample at a time from a single segment file
        with open(file_name, 'r') as f:
            line = f.readline()
            while line:
                yield self.process_line(line)
                line = f.readline()
    
    def process_line(self, line):
        # a line has 40 tab-separated fields: label, 13 integer features, 26 categorical features
        items = re.split('\t', line.replace('\n', ""))
        assert len(items) == 40
        y = int(items[0])
        x = [0. if item == '' else int(item) for item in items[1:14]]
        for f in items[14:]:
            x.append(self.word2id["***" if f == "" else f])
        return torch.tensor(x, dtype=torch.float32), torch.tensor(y, dtype=torch.float32)

    def __iter__(self):
        # shard samples across DataLoader workers via islice below
        worker_info = torch.utils.data.get_worker_info()
        if worker_info:
            worker_total_num = worker_info.num_workers
            worker_id = worker_info.id
        else:
            worker_id = 0
            worker_total_num = 1
        # reseed so every rank/worker permutes the file list the same way each epoch
        self.generator.manual_seed(self.seed + self.epoch)
        if self.distributed:
            local_rank = dist.get_rank()
            idx = torch.randperm(len(self.file_list), generator=self.generator).numpy()
            self.file_list = (np.array(self.file_list)[idx]).tolist()
            # each rank reads its own contiguous slice of the permuted file list
            file_itr = self.file_list[local_rank * self.file_num_per_rank:(local_rank + 1) * self.file_num_per_rank]
            if len(file_itr) < self.file_num_per_rank:
                # pad the last rank by re-sampling some of its own files
                file_itr.extend(random.sample(file_itr, self.file_num_per_rank - len(file_itr)))
        else:
            idx = torch.randperm(len(self.file_list), generator=self.generator).numpy()
            self.file_list = (np.array(self.file_list)[idx]).tolist()
            file_itr = self.file_list

        # chain the per-file generators, then give every worker an interleaved slice of the stream
        file_mapped_itr = chain.from_iterable(map(self.file_mapper, file_itr))
        file_mapped_itr = islice(file_mapped_itr, worker_id, None, worker_total_num)

        if self.shuffle:
            return self._shuffle(file_mapped_itr)
        else:
            return file_mapped_itr
    
    def get_file_line_num(self, path):
        # 'rU' mode was removed in Python 3.11; count lines without reading the file into memory
        with open(path, 'r') as f:
            return sum(1 for _ in f)

    def __len__(self):
        # approximate length: assumes every segment file has file_line_num lines
        if self.distributed:
            return self.file_num_per_rank * self.file_line_num
        return len(self.file_list) * self.file_line_num
        
    def generate_random_num(self):
        # endless stream of random buffer indices, drawn batch_size at a time
        while True:
            random_nums = random.sample(range(self.buffer_size), self.batch_size)
            yield from random_nums

    def _shuffle(self, mapped_itr):
        # reservoir-style buffered shuffle: keep buffer_size samples, emit a random buffered
        # sample for every new incoming one, then flush the remainder fully shuffled
        buffer = []
        random_idx = self.generate_random_num()  # create the index generator once, not per sample
        for dt in mapped_itr:
            if len(buffer) < self.buffer_size:
                buffer.append(dt)
            else:
                i = next(random_idx)
                yield buffer[i]
                buffer[i] = dt
        random.shuffle(buffer)
        yield from buffer
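
Finally, a hedged end-to-end usage sketch, assuming the directory layout built above (./data/criteo_4500/ with segmented_train/, segmented_test/ and word2id.pkl). The dataset yields single (x, y) samples, so batching is left to the DataLoader:

from torch.utils.data import DataLoader

train_dataset, val_dataset, word2id, num_int_feat = create_iterable_criteo_dataset(
    "./data/criteo_4500/", batch_size=1024, shuffle=True, buffer_size=100000)

# For an IterableDataset, do not pass shuffle/sampler to the DataLoader;
# sharding across num_workers is already handled inside __iter__ via islice.
train_loader = DataLoader(train_dataset, batch_size=1024, num_workers=4)

for epoch in range(3):
    train_dataset.set_epoch(epoch)   # reseed the per-epoch file-order permutation
    for x, y in train_loader:
        # x: [batch, 39] float tensor (13 integer + 26 categorical-id features), y: [batch] labels
        pass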
