pytorch读取自己的数据集_给训练踩踩油门 —— Pytorch 加速数据读取

1cd3c92c010d422f4600d0e0118d2671.png

需求

最近在训练 coco 数据集,训练集就有 11 万张,训练一个 epoch 就要将近 100 分钟,训练 100 个 epoch,就需要 7 天!这实在是太慢了。

经过观察,发现训练时 GPU 利用率不是很稳定,每训练 5 秒,利用率都要从 100% 掉到 0% 一两秒,初步判断是数据读取那块出现了瓶颈。于是经过调研和实验,制定了下列解决方案。

解决方案

(1)prefetch_generator

使用 prefetch_generator 库在后台加载下一 batch 的数据。

安装:

pip install prefetch_generator

使用:

# 新建DataLoaderX类
from torch.utils.data import DataLoader
from prefetch_generator import BackgroundGenerator

class DataLoaderX(DataLoader):

    def __iter__(self):
        return BackgroundGenerator(super().__iter__())

然后用 DataLoaderX 替换原本的 DataLoader

提速原因:

原本 PyTorch 默认的 DataLoader 会创建一些 worker 线程来预读取新的数据,但是除非这些线程的数据全部都被清空,这些线程才会读下一批数据。
使用 prefetch_generator,我们可以保证线程不会等待,每个线程都总有至少一个数据在加载。

(2)data_prefetcher

更新:经评论区提醒,有同学在用这个技术的时候遇到显存溢出的问题,见 apex issues ,大家用的时候注意一下。

使用 data_prefetcher 新开 cuda stream 来拷贝 tensor 到 gpu。

使用:

class DataPrefetcher():
    def __init__(self, loader, opt):
        
  • 0
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值