深度学习--基于队列的数据随机载入

  • 起因:

    • 最近在处理数据,训练深度学习模型的时候,发现使用pytorch中的dataloader的时候会占用大量缓存,拖慢数据的处理速度,凑巧在学习 李沐老师的深度学习课程时,李沐老师给出了解决方案,
  • 解决方案如下:

  • 将数据保存硬盘上,每次读取两个batch,一个batch直接用于训练,一个batch备用,这样可以尽可能少的占用缓存,同时保证运行速度

  • 以下是我的解决方案,主要思想是基于队列先进先出的思想,首先使用python构建队列:

class Queue:
    def __init__(self):
        self.items = []
    
    def enqueue(self, item):
        ### 将数据载入队列中 ###
        self.items.append(item)
    
    def dequeue(self):
        ### 先进先出 ###
        return self.items.pop(0)
    
    def empty(self):
        ### 判断队列是否为空 ###
        return self.size() == 0 

    def size(self):
        ### 队列中的元素个数 ###
        return len(self.items)
  • 基于队列构建数据的循环读取,每次队列中只加载两个batch,代码如下:
class Queue_Data_Load:
    def __init__(self, path, sampNum, seled_file, max_len=2):
        """
        sampNum: list batch的编号列表
        seled_file: list 全局变量,已选择过的编号
        max_len: int 最大储存batch数量
        path: str 数据储存位置 "/YourDataPath/{sample}.np"
        """
        self.data_queue = Queue()
        self.samp_ls = sampNum
        self.max_len = max_len
        self.data_can_in = True
        self.path = path
        self.seled_file = seled_file

    def data_read(self):
        
        '''随机选取保证随机性'''
        isam = random.choice(self.samp_ls)
        while (len(self.seled_file) < len(self.samp_ls)):
            if isam not in self.seled_file:
                self.seled_file.append(isam)
                break
            else:
                isam = random.choice(self.samp_ls)
        else:
            self.data_can_in = False
        filePath = self.path.format(sample=isam)
        ### 这里可以添加按地址读取数据的程序 ###
        return filePath 
    
    def data_queue_load(self):
        """数据读取 只在队列中数据量小于max_len时执行"""
        while self.data_queue.size() < self.max_len and self.data_can_in:
            self.data_queue.enqueue(self.data_read())
    
    def data_queue_out(self):
        self.data_queue_load()
        if (self.data_queue.size() == self.max_len and self.data_can_in): 
            ### 当队列中数据小于max_len 且数据还可以读取时 ###
            return self.data_queue.dequeue()
        if not (self.data_can_in or self.data_queue.empty()):
            ### 当没有剩余数据时,保证队列中所有数据都输出 ###
            return self.data_queue.dequeue()
  • 完整代码及演示如下:
import random

class Queue:
    def __init__(self):
        self.items = []
    
    def enqueue(self, item):
        ### 将数据载入队列中 ###
        self.items.append(item)
    
    def dequeue(self):
        ### 先进先出 ###
        return self.items.pop(0)
    
    def empty(self):
        ### 判断队列是否为空 ###
        return self.size() == 0 

    def size(self):
        ### 队列中的元素个数 ###
        return len(self.items)
    
class Queue_Data_Load:
    def __init__(self, path, sampNum, seled_file, max_len=2):
        """
        sampNum: list batch的编号列表
        seled_file: list 全局变量,已选择过的编号
        max_len: int 最大储存batch数量
        path: str 数据储存位置 "/YourDataPath/{sample}.np"
        """
        self.data_queue = Queue()
        self.samp_ls = sampNum
        self.max_len = max_len
        self.data_can_in = True
        self.path = path
        self.seled_file = seled_file

    def data_read(self):
        
        '''随机选取保证随机性'''
        isam = random.choice(self.samp_ls)
        while (len(self.seled_file) < len(self.samp_ls)):
            if isam not in self.seled_file:
                self.seled_file.append(isam)
                break
            else:
                isam = random.choice(self.samp_ls)
        else:
            self.data_can_in = False
        filePath = self.path.format(sample=isam)
        return filePath 
    
    def data_queue_load(self):
        """数据读取 只在队列中数据量小于max_len时执行"""
        while self.data_queue.size() < self.max_len and self.data_can_in:
            self.data_queue.enqueue(self.data_read())
    
    def data_queue_out(self):
        self.data_queue_load()
        if (self.data_queue.size() == self.max_len and self.data_can_in): 
            ### 当队列中数据小于max_len 且数据还可以读取时 ###
            return self.data_queue.dequeue()
        if not (self.data_can_in or self.data_queue.empty()):
            ### 当没有剩余数据时,保证队列中所有数据都输出 ###
            return self.data_queue.dequeue()
 
        
if __name__ == '__main__':
    path = 'data/{sample}'
    sampNum = list(range(0,10))
    global seled_file 
    seled_file = []
    qdl = Queue_Data_Load(path, sampNum, seled_file)
    for i in range(12):
        print(qdl.data_queue_out(), qdl.data_queue.size())
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值