Pytorch IterableDataset的使用

Pytorch IterableDataset的使用

背景

当数据量特别大,无法一次性load进内存时,Pytorch里的Dataset就无法胜任了,此时需要使用IterableDataset.

基本用法

只需要实现__init__()__iter__()__len__(),模版如下:

from torch.utils.data import IterableDataset, DataLoader

class MyIterableDataset(IterableDataset):
	def __init__(self):
		# 实现初始化代码
		pass
	
	def __iter__(self):
		# 返回一个数据的迭代器
		pass
	
	def __len__(self):
		# 返回数据长度
		pass

mydataset = MyIterableDataset()  # 可迭代对象
mydataloader = DataLoader(mydataset, shuffle=False, batch_size=batch_size, num_workers=num_workers)  # shuffle必须要是False

一个例子

读取CNN_Dailymail摘要数据集

class SummaryDataset(IterableDataset):                                          
                                                                                
    def __init__(self,                                                          
                 file_path: str                                                 
    ):                                                                          
        super(SummaryDataset).__init__()                                        
        self.file_path = file_path                                              
        self.info  = self._get_file_info(file_path)                             
        self.start = self.info['start']                                         
        self.end   = self.info['end']                                           
                                                                                
    def __iter__(self):                                                         
        worker_info = torch.utils.data.get_worker_info()                        
        if worker_info is None:  # single worker                                
            iter_start = self.start                                             
            iter_end   = self.end                                               
        else:  # multiple workers                                               
            per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
            worker_id = worker_info.id                                          
            iter_start = self.start + worker_id * per_worker                    
            iter_end = min(iter_start + per_worker, self.end)                   
        sample_iterator = self._sample_generator(iter_start, iter_end)          
        return sample_iterator                                                  
                                                                                
    def __len__(self):                                                          
        return self.end - self.start                                            
                                                                                
    def _get_file_info(self,                                                    
                       file_path                                                
    ):                                                                          
        info = {                                                                
            "start": 1,                                                         
            "end": 0,                                                           
            "id_colum": 0,                                                      
            "article_colum": 1,                                                 
            "summary_colum": 2                                                  
        }                                                                       
        with open(file_path, 'r') as fin:                                       
            for _ in enumerate(fin):                                            
                info['end'] += 1                                                                                            
        return info                                                             
                                                                                
    def _sample_generator(self, start, end):                                    
        id_c, art_c, sum_c = self.info['id_colum'], self.info['article_colum'], self.info['summary_colum']
        with open(self.file_path, 'r') as fin:                                  
            for i, line in enumerate(fin):                                      
                if i < start: continue                                          
                if i >= end: return StopIteration()                             
                items = line.strip().split('\t')                                
                sample = {"id": items[id_c], "article": items[art_c], "summary": items[sum_c]}
                yield sample

train_dataset = SummaryDataset(args.train_dataset)                                                  
train_dataloader = DataLoader(train_dataset, shuffle=False, batch_size=args.batch_size, num_workers=args.num_workers) 

如果要配合DistributedDataParallel进行多进程分布式训练,每个进程load不同段的数据,尽量不要在subprocess里通过num_workers再开子进程,容易出问题,推荐以下写法:

class SummaryDataset(IterableDataset):                                          
                                                                                
    def __init__(self,                                                          
                 file_path: str,                                                
                 rank,                                                          
                 world_size                                                     
    ):                                                                          
        super(SummaryDataset).__init__()                                        
        self.file_path = file_path                                              
        self.info  = self._get_file_info(file_path)                             
        self.start = self.info['start']                                         
        self.end   = self.info['end']                                           
                                                                                
        self.rank  = rank                                                       
        self.world_size = world_size                                            
                                                                                
        self.per_worker = int(math.floor((self.end - self.start) / float(self.world_size)))
        self.iter_start = self.start + self.rank * self.per_worker              
        self.iter_end = min(self.iter_start + self.per_worker, self.end)        
                                                                                
    def __iter__(self):                                                         
        sample_iterator = self._sample_generator(self.iter_start, self.iter_end) 
        return sample_iterator                                                  
                                                                                
    def __len__(self):                                                          
        return self.iter_end - self.iter_start                                  
                                                                                
    def _get_file_info(self,                                                    
                       file_path                                                
    ):                                                                          
        info = {                                                                
            "start": 1,                                                         
            "end": 0,                                                           
            "id_colum": 0,                                                      
            "article_colum": 1,                                                 
            "summary_colum": 2                                                  
        }                                                                       
        with open(file_path, 'r') as fin:                                       
            for _ in enumerate(fin):                                            
                info['end'] += 1                                                                                           
        return info                                                             
                                                                                
    def _sample_generator(self, start, end):                                    
        id_c, art_c, sum_c = self.info['id_colum'], self.info['article_colum'], self.info['summary_colum']
        with open(self.file_path, 'r') as fin:                                  
            for i, line in enumerate(fin):                                      
                if i < start: continue                                          
                if i >= end: return StopIteration()                             
                items = line.strip().split('\t')                                
                sample = {"id": items[id_c], "article": items[art_c], "summary": items[sum_c]}
                yield sample

def train_worker(rank, args):
	# 子进程
	...
	train_dataset = SummaryDataset(args.train_dataset, rank, args.world_size)
	rain_dataloader = DataLoader(train_dataset, shuffle=False, batch_size=args.batch_size, num_workers=0)
	...
  • 6
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
首先,为了安装和使用PyTorch,你需要选择合适的操作系统、Python版本和CUDA版本。根据引用中提供的信息,你可以选择稳定版的Windows操作系统、Python语言和CUDA版本11.3。然后,你可以按照以下步骤安装和使用PyTorch: 1. 打开终端或命令提示符,并创建一个新的PyTorch环境。你可以使用Anaconda或Miniconda来管理你的环境。在终端中执行以下命令来创建一个名为pytorch的新环境,并安装PyTorch和相关软件包: ```bash conda create --name pytorch python=3.8 conda activate pytorch conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch ``` 请确保你已经复制了上述conda命令,并将其粘贴到终端中执行。等待安装完成。 2. 安装PyCharm集成开发环境(IDE)。你可以从官方网站下载PyCharm的稳定版本并进行安装。安装过程非常简单,只需按照提示进行操作即可。根据引用中的信息,你可以开始验证PyTorch是否成功安装。 3. 打开PyCharm并创建一个新的项目。在项目设置中,选择之前创建的PyTorch环境(pytorch)作为项目的解释器。 4. 在PyCharm的终端中,确保你已经激活pytorch环境。如果没有激活,你可以使用以下命令激活它: ```bash conda activate pytorch ``` 5. 在激活的pytorch环境中,你可以使用以下命令来安装PyTorch: ```bash conda install pytorch torchvision cudatoolkit=11.3 ``` 等待安装完成。 6. 验证PyTorch是否成功安装。在PyCharm的终端中,输入以下Python代码并执行: ```python import torch print(torch.__version__) ``` 如果你看到了PyTorch的版本号输出,那么恭喜你,PyTorch安装成功了! 请注意,以上步骤基于引用和中提供的信息,并假设你已经正确安装了Anaconda或Miniconda和PyCharm。如果你遇到了任何问题,可以参考官方文档或在相关社区寻求帮助。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* *2* *3* [pytorch安装教程新手入门](https://blog.csdn.net/qq_45547409/article/details/127182762)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_1"}}] [.reference_item style="max-width: 100%"] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值