Using PyTorch's IterableDataset

Background

When a dataset is too large to load into memory all at once, PyTorch's map-style Dataset is no longer adequate; in that case you need IterableDataset.

Basic usage

You only need to implement __init__(), __iter__(), and __len__(). The template looks like this:

from torch.utils.data import IterableDataset, DataLoader

class MyIterableDataset(IterableDataset):
	def __init__(self):
		# initialization code
		pass

	def __iter__(self):
		# return an iterator over the samples
		pass

	def __len__(self):
		# return the number of samples
		pass

mydataset = MyIterableDataset()  # an iterable object
mydataloader = DataLoader(mydataset, shuffle=False, batch_size=batch_size, num_workers=num_workers)  # shuffle must be False for an IterableDataset (DataLoader rejects shuffle=True)
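To make the template concrete, here is a minimal sketch that fills in the three methods with a toy dataset streaming the squares of 0..n-1; the class name and range are made up for illustration:

from torch.utils.data import IterableDataset, DataLoader

class SquaresDataset(IterableDataset):
    def __init__(self, n):
        super().__init__()
        self.n = n

    def __iter__(self):
        # a generator expression is itself an iterator
        return (i * i for i in range(self.n))

    def __len__(self):
        return self.n

loader = DataLoader(SquaresDataset(8), shuffle=False, batch_size=4)
for batch in loader:
    print(batch)  # tensor([0, 1, 4, 9]) then tensor([16, 25, 36, 49])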

An example

Reading the CNN_Dailymail summarization dataset, stored as a tab-separated file with id, article, and summary columns:

import math

import torch
from torch.utils.data import IterableDataset, DataLoader

class SummaryDataset(IterableDataset):

    def __init__(self,
                 file_path: str
    ):
        super().__init__()
        self.file_path = file_path
        self.info  = self._get_file_info(file_path)
        self.start = self.info['start']
        self.end   = self.info['end']

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:  # single worker: iterate over the whole range
            iter_start = self.start
            iter_end   = self.end
        else:  # multiple workers: each worker reads a contiguous slice of lines
            per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
            worker_id = worker_info.id
            iter_start = self.start + worker_id * per_worker
            iter_end = min(iter_start + per_worker, self.end)
        return self._sample_generator(iter_start, iter_end)

    def __len__(self):
        return self.end - self.start

    def _get_file_info(self,
                       file_path
    ):
        info = {
            "start": 1,  # line 0 is skipped (presumably a header row)
            "end": 0,
            "id_column": 0,
            "article_column": 1,
            "summary_column": 2
        }
        # count the number of lines in the file
        with open(file_path, 'r') as fin:
            for _ in fin:
                info['end'] += 1
        return info

    def _sample_generator(self, start, end):
        id_c, art_c, sum_c = self.info['id_column'], self.info['article_column'], self.info['summary_column']
        with open(self.file_path, 'r') as fin:
            for i, line in enumerate(fin):
                if i < start:
                    continue
                if i >= end:
                    break  # just end the generator; do not return a StopIteration instance
                items = line.strip().split('\t')
                yield {"id": items[id_c], "article": items[art_c], "summary": items[sum_c]}

train_dataset = SummaryDataset(args.train_dataset)                                                  
train_dataloader = DataLoader(train_dataset, shuffle=False, batch_size=args.batch_size, num_workers=args.num_workers) 
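With the default collate function, each batch comes back as a dict whose values are lists of batch_size strings. A quick sanity-check loop (a hedged sketch, assuming args has already been parsed as above):

for step, batch in enumerate(train_dataloader):
    # batch is {"id": [...], "article": [...], "summary": [...]}
    print(step, len(batch["article"]))
    if step == 2:
        break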

When combining this with DistributedDataParallel for multi-process distributed training, each training process should load its own slice of the data. Avoid having each subprocess spawn further DataLoader worker processes via num_workers, which is prone to problems. The recommended pattern is as follows (a launch sketch appears after train_worker below):

class SummaryDataset(IterableDataset):

    def __init__(self,
                 file_path: str,
                 rank,
                 world_size
    ):
        super().__init__()
        self.file_path = file_path
        self.info  = self._get_file_info(file_path)
        self.start = self.info['start']
        self.end   = self.info['end']

        self.rank  = rank
        self.world_size = world_size

        # each rank reads a fixed contiguous slice of the file
        self.per_worker = int(math.floor((self.end - self.start) / float(self.world_size)))
        self.iter_start = self.start + self.rank * self.per_worker
        self.iter_end = min(self.iter_start + self.per_worker, self.end)

    def __iter__(self):
        return self._sample_generator(self.iter_start, self.iter_end)

    def __len__(self):
        return self.iter_end - self.iter_start

    def _get_file_info(self,
                       file_path
    ):
        info = {
            "start": 1,  # line 0 is skipped (presumably a header row)
            "end": 0,
            "id_column": 0,
            "article_column": 1,
            "summary_column": 2
        }
        # count the number of lines in the file
        with open(file_path, 'r') as fin:
            for _ in fin:
                info['end'] += 1
        return info

    def _sample_generator(self, start, end):
        id_c, art_c, sum_c = self.info['id_column'], self.info['article_column'], self.info['summary_column']
        with open(self.file_path, 'r') as fin:
            for i, line in enumerate(fin):
                if i < start:
                    continue
                if i >= end:
                    break
                items = line.strip().split('\t')
                yield {"id": items[id_c], "article": items[art_c], "summary": items[sum_c]}

def train_worker(rank, args):
	# runs inside one training subprocess (one per rank)
	...
	train_dataset = SummaryDataset(args.train_dataset, rank, args.world_size)
	train_dataloader = DataLoader(train_dataset, shuffle=False, batch_size=args.batch_size, num_workers=0)
	...
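For completeness, a hedged sketch of how train_worker might be launched with torch.multiprocessing.spawn; parse_args is a hypothetical helper, and args.world_size is carried over from the snippet above:

import torch.multiprocessing as mp

if __name__ == '__main__':
    args = parse_args()  # hypothetical: fills in train_dataset, batch_size, world_size, ...
    # spawn one training process per rank; each process builds its own
    # SummaryDataset slice, so no extra DataLoader worker processes are needed
    mp.spawn(train_worker, args=(args,), nprocs=args.world_size)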