关于pytorch使用多个dataloader并使用zip和cycle来进行循环时出现的显存泄漏的问题

13 篇文章 2 订阅

关于pytorch使用多个dataloader并使用zip和cycle来进行循环时出现的显存泄漏的问题

如果我们想要在 Pytorch 中同时迭代两个 dataloader 来处理数据,会有两种情况:一是我们按照较短的 dataloader 来迭代,长的 dataloader 超过的部分就丢弃掉;二是比较常见的,我们想要按照较长的 dataloader 来迭代,短的 dataloader 在循环完一遍再循环一遍,直到长的 dataloader 循环完一遍。

两个dataloader的写法及问题的出现

第一种情况很好写,直接用 zip 包一下两个 dataloader 即可:

# ...
dataloaders1 = DataLoader(DummyDataset(0, 100), batch_size=10, shuffle=True)
dataloaders2 = DataLoader(DummyDataset(0, 200), batch_size=10, shuffle=True)
num_epochs = 10

for epoch in range(num_epochs):
    for i, data in enumerate(zip(dataloaders1, dataloaders2)):
        print(data)
        # 开始写你的训练脚本

第二种情况笔者一开始时参考的一篇博客的写法,用 cycle 将较短的 dataloader 包一下:

from itertools import cycle
# ...
dataloaders1 = DataLoader(DummyDataset(0, 100), batch_size=10, shuffle=True)
dataloaders2 = DataLoader(DummyDataset(0, 200), batch_size=10, shuffle=True)
num_epochs = 10

for epoch in range(num_epochs):
    for i, data in enumerate(zip(cycle(dataloaders1), dataloaders2)):
        print(data)
        # 开始写你的训练脚本

是可以运行,但是这样出现了明显显存泄漏的问题,在笔者自己的实验中,显存占用量会随着训练的进行,每轮增加 20M 左右,最终导致显存溢出,程序失败。

解决方法

笔者找了半天,终于在 StackOverflow 的一篇贴子中找到了解决方法,该贴的一个答案指出:cyclezip 的方法确实可能会造成显存泄漏(memory leakage)的问题,尤其是在使用图像数据集时,可以通过以下写法来迭代两个 dataloader 并避免这个问题:

# ...
dataloaders1 = DataLoader(DummyDataset(0, 100), batch_size=10, shuffle=True)
dataloaders2 = DataLoader(DummyDataset(0, 200), batch_size=10, shuffle=True)
num_epochs = 10

for epoch in range(num_epochs):
    dataloader_iterator1 = iter(dataloaders1)
    
    for i, data2 in enumerate(dataloaders2):

        try:
            data1 = next(dataloader_iterator1)
        except StopIteration:
            dataloader_iterator1 = iter(dataloaders1)
            data1 = next(dataloader_iterator1)
        print(data1, data2)

        # 开始你的训练脚本

笔者亲测这种方式是可以正常运行且不会有显存泄漏问题的。

  • 18
    点赞
  • 28
    收藏
    觉得还不错? 一键收藏
  • 5
    评论
好的,下面是使用PyTorch中的Dataset和Dataloader的示例程序: 首先,我们需要导入相关的库: ```python import torch from torch.utils.data import Dataset, DataLoader ``` 接下来我们定义一个数据集类,这个类需要继承自`torch.utils.data.Dataset`,并实现`__len__`和`__getitem__`方法。这里我们假设我们有一个包含10个样本的数据集,每个样本由一个图片和一个标签组成: ```python class MyDataset(Dataset): def __init__(self): self.data = [(torch.randn(3, 224, 224), i) for i in range(10)] def __len__(self): return len(self.data) def __getitem__(self, index): x, y = self.data[index] return x, y ``` 接下来,我们可以使用`MyDataset`类创建一个数据集对象,并使用`DataLoader`类创建一个数据加载器对象。在创建`DataLoader`对象,我们需要指定`batch_size`和`shuffle`参数: ```python my_dataset = MyDataset() my_dataloader = DataLoader(my_dataset, batch_size=2, shuffle=True) ``` 现在我们可以使用`my_dataloader`迭代数据集中的样本了,每个迭代器返回一个包含`batch_size`个样本的元组,其中第一个元素是一个大小为`(batch_size, 3, 224, 224)`的张量,代表`batch_size`个图片,第二个元素是一个大小为`(batch_size,)`的张量,代表`batch_size`个标签。我们可以使用下面的代码来迭代数据集: ```python for x, y in my_dataloader: print(x.shape, y.shape) ``` 输出结果如下: ``` torch.Size([2, 3, 224, 224]) torch.Size([2]) torch.Size([2, 3, 224, 224]) torch.Size([2]) torch.Size([2, 3, 224, 224]) torch.Size([2]) torch.Size([2, 3, 224, 224]) torch.Size([2]) torch.Size([2, 3, 224, 224]) torch.Size([2]) ``` 这个程序演示了如何使用PyTorch中的Dataset和Dataloader来加载数据集,并迭代数据集中的样本。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值