python gpu 显存利用率_python多线程预读取数据,提高GPU利用率,训练效率可提升30%!...

本文介绍了如何通过Python多线程预读取数据,以提高GPU显存利用率,特别是在使用PyTorch训练网络时,解决数据读取阶段的瓶颈,从而提升训练效率30%。通过创建数据队列data_queue,数据读取线程不断填充队列,模型训练线程则从队列中获取数据,利用Queue对象的阻塞性质确保数据供应,避免GPU空闲。
摘要由CSDN通过智能技术生成

在使用pytorch训练网络时,有没有想提高GPU利用率,加快网络训练的同学啊?

GPU利用率不高,很多都是因为单线程的模式:数据读取->模型训练。数据读取是在CPU上进行,如果在此阶段进行图片预处理的步骤较多,那么GPU就只能等了,导致利用率上不去。

So~,可以使用多线程的方法来预读取数据,保证数据每时每刻都已经读取完毕,放在了内存当中,等待模型训练。(如果数据量小,内存足够的话,也可以全部放入内存当中)

下面就说下我的实现思路:

定义一个数据队列data_queue,利用queue.Queue()的阻塞,保证data_queue里面数据的个数保持稳定。数据读取线程往dataqueue存放数据,模型训练线程从data_queue读取数据。

如下图,可以看到Queue对象的阻塞性质:当block=True,timeout=None时,会进行阻塞直到队列中有空位可以put数据。真是利用这一点,现实了在不占用不太多内存的情况下,提高GPU利用率。

8051fd2300642cb6a9afaa14ccca380e.png

下面是我的程序框架:

import torch
import threading
from queue import Queue
import ...

# 预读取数据存储在队列,元素个数设置为5
data_queue = Queue(5)

# 全局标志位,标志所有训练周期是否结束
train_done = False

def read_data(*args):
    
    # 定义dataset
    dataset = LoadImagesAndLabels(train_path, img_size=img_size, augment=True, cls=cls)
    
    # 定义 Dataloader
    dataloader = Dataloader(dataset,batch_size,num_workers,shuffle,...)
    
    while not train_done: 
        # 生成 batch 数据
        for i,(imgs,targets) in enumerate(dataloader):
            imgs = imgs.to(device)
            targets = targets.to(device)
            data_queue.put([imgs,targets]) # put到队列数据,超过所设个数,会阻塞此线程
        
        event.clear()
        while not event.is_set():  # 一个周期结束,等待事件 event.set()
            pass

def train(*args):
    
    # 模型
    model = model()
    
    # 优化器
    optimizer = torch.optim....
    
    # 开始训练模型
    for epoch in range(epochs):
        model.train()
        
        # event 事件开启,开始新一轮数据读取
        event.set()
        while data_queue.qsize()<10:   # 等待读取数据
            pass    
        
        # 每个 epoch 周期
        batch_i = 0 # batch
        while not data_queue.empty():
            imgs,targets = data_queue.get()   # get 获得队列的第一个元素,并删除
            
            # Run model
            pred = model(imgs)
            
            # Compute loss
            loss = compute_loss(pred,targets)
            
            # Backward
            loss.backward()
            loss.step()
            loss.zero_grad()
    
    
if __name__=='__main__':
    
    # 设置好输入参数
    parser = argparse.ArgumentParser()
    parser.add_argument(...)
    opt = parser.parse_args()  
    
    # 创建 read_data()线程 和 train()线程
    event = Event()  # 定义事件,实现数据的循环读取
    
    thread_data = threading.Thread(target=read_data,args=(args))
    time.sleep(0.5)  # 保证队列中已有数据
    thread_train = threading.Thread(target=train,args=(args))
    
    # 等待trian()线程结束
    thread_train.join()

小白代码,大家多多交流~

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值