-
起因:
- 最近在处理数据,训练深度学习模型的时候,发现使用pytorch中的dataloader的时候会占用大量缓存,拖慢数据的处理速度,凑巧在学习 李沐老师的深度学习课程时,李沐老师给出了解决方案,
-
解决方案如下:
-
将数据保存硬盘上,每次读取两个batch,一个batch直接用于训练,一个batch备用,这样可以尽可能少的占用缓存,同时保证运行速度
-
以下是我的解决方案,主要思想是基于队列先进先出的思想,首先使用python构建队列:
class Queue:
def __init__(self):
self.items = []
def enqueue(self, item):
### 将数据载入队列中 ###
self.items.append(item)
def dequeue(self):
### 先进先出 ###
return self.items.pop(0)
def empty(self):
### 判断队列是否为空 ###
return self.size() == 0
def size(self):
### 队列中的元素个数 ###
return len(self.items)
- 基于队列构建数据的循环读取,每次队列中只加载两个batch,代码如下:
class Queue_Data_Load:
def __init__(self, path, sampNum, seled_file, max_len=2):
"""
sampNum: list batch的编号列表
seled_file: list 全局变量,已选择过的编号
max_len: int 最大储存batch数量
path: str 数据储存位置 "/YourDataPath/{sample}.np"
"""
self.data_queue = Queue()
self.samp_ls = sampNum
self.max_len = max_len
self.data_can_in = True
self.path = path
self.seled_file = seled_file
def data_read(self):
'''随机选取保证随机性'''
isam = random.choice(self.samp_ls)
while (len(self.seled_file) < len(self.samp_ls)):
if isam not in self.seled_file:
self.seled_file.append(isam)
break
else:
isam = random.choice(self.samp_ls)
else:
self.data_can_in = False
filePath = self.path.format(sample=isam)
### 这里可以添加按地址读取数据的程序 ###
return filePath
def data_queue_load(self):
"""数据读取 只在队列中数据量小于max_len时执行"""
while self.data_queue.size() < self.max_len and self.data_can_in:
self.data_queue.enqueue(self.data_read())
def data_queue_out(self):
self.data_queue_load()
if (self.data_queue.size() == self.max_len and self.data_can_in):
### 当队列中数据小于max_len 且数据还可以读取时 ###
return self.data_queue.dequeue()
if not (self.data_can_in or self.data_queue.empty()):
### 当没有剩余数据时,保证队列中所有数据都输出 ###
return self.data_queue.dequeue()
- 完整代码及演示如下:
import random
class Queue:
def __init__(self):
self.items = []
def enqueue(self, item):
### 将数据载入队列中 ###
self.items.append(item)
def dequeue(self):
### 先进先出 ###
return self.items.pop(0)
def empty(self):
### 判断队列是否为空 ###
return self.size() == 0
def size(self):
### 队列中的元素个数 ###
return len(self.items)
class Queue_Data_Load:
def __init__(self, path, sampNum, seled_file, max_len=2):
"""
sampNum: list batch的编号列表
seled_file: list 全局变量,已选择过的编号
max_len: int 最大储存batch数量
path: str 数据储存位置 "/YourDataPath/{sample}.np"
"""
self.data_queue = Queue()
self.samp_ls = sampNum
self.max_len = max_len
self.data_can_in = True
self.path = path
self.seled_file = seled_file
def data_read(self):
'''随机选取保证随机性'''
isam = random.choice(self.samp_ls)
while (len(self.seled_file) < len(self.samp_ls)):
if isam not in self.seled_file:
self.seled_file.append(isam)
break
else:
isam = random.choice(self.samp_ls)
else:
self.data_can_in = False
filePath = self.path.format(sample=isam)
return filePath
def data_queue_load(self):
"""数据读取 只在队列中数据量小于max_len时执行"""
while self.data_queue.size() < self.max_len and self.data_can_in:
self.data_queue.enqueue(self.data_read())
def data_queue_out(self):
self.data_queue_load()
if (self.data_queue.size() == self.max_len and self.data_can_in):
### 当队列中数据小于max_len 且数据还可以读取时 ###
return self.data_queue.dequeue()
if not (self.data_can_in or self.data_queue.empty()):
### 当没有剩余数据时,保证队列中所有数据都输出 ###
return self.data_queue.dequeue()
if __name__ == '__main__':
path = 'data/{sample}'
sampNum = list(range(0,10))
global seled_file
seled_file = []
qdl = Queue_Data_Load(path, sampNum, seled_file)
for i in range(12):
print(qdl.data_queue_out(), qdl.data_queue.size())