pytorch dataloader_一分钟学 Trick: PyTorch 动态更新 DataLoader

[很长一段时间,文章的收藏量都是点赞量的三倍多,大家如果感觉有用,请点个“赞”来资瓷下哇 ~]

f631897a4d34ee7026ce2116c6ed6a4b.png

PyTorch 动态更新 DataLoader

我们知道 PyTorch 里面的 DataLoader 是多进程的,用起来非常方便。但是多线程带来问题就是,会有额外的内存消耗。

最近遇到了一个需求,需要不断的更新 DataLoader 。当然,最简单的方法还是每次需要更新 DataLoader 的时候,重新新建一个 DataLoader,但是这种方法的代价是很大的,尤其是当我们需要频繁的更新 DataLoader 的时候。

举个例子,我们只想取 Dataset 中的一部分,所以可以使用 SubsetRandomSampler

from torch.utils.data import DataLoader, TensorDataset, SubsetRandomSampler, BatchSampler
import torch

candidate = [1]
dataset = TensorDataset(torch.tensor(list(range(10))))
dataloader = DataLoader(dataset, batch_size=1, sampler=SubsetRandomSampler(candidate))

for idx, data in enumerate(dataloader):
    print(data)

输出的结果肯定是只取第二个数据,也就是

[tensor([1])]

但是假如我们在不重新新建一个 DataLoader 的情况下,更新这个 SubsetRandomSampler ,怎么办?

一个简单的想法是,我们能否直接更新 dataloader.sampler 以及 dataloader.batch_sampler

candidate = [2]
dataloader.sampler = SubsetRandomSampler(candidate)
dataloader.batch_sampler = BatchSampler(SubsetRandomSampler(candidate), batch_size=1, drop_last=True)

遗憾的是,PyTorch 并不支持这个操作。

ba4703062fcc2b0828f4dc3d474b72e0.png

本文给出一个非常简单的解决方法。

In-place 地改变 candidate 这个变量就行了。

from torch.utils.data import DataLoader, TensorDataset
import torch

candidate = [1,2]
dataset = TensorDataset(torch.tensor(list(range(10))))
dataloader = DataLoader(dataset, batch_size=1, sampler=SubsetRandomSampler(candidate))

for idx, data in enumerate(dataloader):
    print(data)

[candidate.pop() for i in range(len(candidate))]
candidate.extend([3,4])

for idx, data in enumerate(dataloader):
    print(data)

我们来看下效果:

[tensor([1])]
[tensor([2])]
[tensor([4])]
[tensor([3])]

可以看到这个方法是成功的。

中间两行:

[candidate.pop() for i in range(len(candidate))]
candidate.extend([3,4])

是先 in-place 地把这个 list 清空,然后再 in-place 地更新这个 list。

注意,这里的关键是 in-place!

如果你直接对 candidate 进行赋值,是无法达到修改效果的。

from torch.utils.data import DataLoader, TensorDataset
import torch

candidate = [1,2]
dataset = TensorDataset(torch.tensor(list(range(10))))
dataloader = DataLoader(dataset, batch_size=1, sampler=SubsetRandomSampler(candidate))

for idx, data in enumerate(dataloader):
    print(data)

candidate = [3,4]

for idx, data in enumerate(dataloader):
    print(data)

无法达到效果:

[tensor([1])]
[tensor([2])]
[tensor([2])]
[tensor([1])]
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值