PyTorch 动态更新 Dataset 对 DataLoader不生效

项目场景

使用 Stitcher 增强小目标检测效果。在每个 iter 训练结束后,根据小目标贡献的 loss 占总 loss 的比重,决定下一个 iter 的数据是否需要 stitch 。

环境:PyTorch 1.3

问题描述

实现上需要使用 Hook 统计 loss 占比,然后设置 Dataset 的标志位。 Dataset.__getitem__ 会首先查询标志位,决定是否返回 stitch 后的结果。运行时发现标志位设置失败, stitch 一直没有激活。

from mmcv.runner import BaseRunner, Hook, HOOKS
from mmdet.datasets.builder import DATASETS


@HOOKS.register_module()
class SmallFeedbackHook(Hook):
    def before_train_epoch(self, runner: BaseRunner):
        data_loader: DataLoader = runner.data_loader
        dataset: StitchedDataset = data_loader.dataset
        dataset.active = False
        self.dataset = dataset

    def after_train_iter(self, runner: BaseRunner):
        ratio = ...
        self.dataset.active = ratio <= self.thresh


@DATASETS.register_module()
class StitchedDataset:
    def __getitem__(self, idx: int) -> Dict[str, Any]:
        if self.active: ...
        else: ...

原因分析

在 Hook 中设置断点发现标志位设置成功,但是 Dataset.__getitem__ 打印的信息却显示标志位设置失败,于是猜测 DataLoader 在创建进程时复制了 Dataset 对象。

参考 ptrblck 提供的代码,可以证明 DataLoader 在创建进程时没有复制 Dataset 对象,而是使用了 Dataset 对象的引用

import os

import torch
from torch.utils.data import DataLoader, Dataset


class MyDataset(Dataset):
    LEN = 30
    
    def __init__(self):
        self.data = torch.zeros(self.LEN, 1)

    def __getitem__(self, index):
        print("Get item", index, "of", id(self), "at", os.getpid())
        return self.data[index]

    def __len__(self):
        return self.LEN

    def regenerate(self):
        print("Regenerate", id(self), "at", os.getpid())
        self.data = torch.ones(self.LEN, 1)


if __name__ == '__main__':
    dataset = MyDataset()
    data_loader = DataLoader(dataset, num_workers=3)

    print(dataset[0])  # output 0
    for d in data_loader:
        print(d)

    dataset.regenerate()

    print(dataset[0])  # output 1
    for d in data_loader:
        print(d)

输出

Get item 1 of 140308728889528 at 25433
tensor([0.])
Get item 1 of 140308728889528 at 31734
Get item 0 of 140308728889528 at 31727
Get item 2 of 140308728889528 at 31744
tensor([[0.]])
...
Regenerate 140308728889528 at 25433
Get item 0 of 140308728889528 at 25433
tensor([1.])
Get item 0 of 140308728889528 at 32116
Get item 1 of 140308728889528 at 32117
Get item 2 of 140308728889528 at 32118
tensor([[1.]])
...

可以看出, 25433 是主进程, 31727/31734/31744 是第一次遍历 DataLoader 创建的三个子进程, 32116/32117/32118是第二次遍历 DataLoader 创建的三个子进程。所有进程都使用 140308728889528 的 MyDataset 对象, MyDataset.regenerate 也生效了。但是如果在一次遍历的过程中调用 MyDataset.regenerate 就无法生效。

if __name__ == '__main__':
    dataset = MyDataset()
    data_loader = DataLoader(dataset, num_workers=3)

    for d in data_loader:
        print(d)
        dataset.regenerate()

输出

Get item 1 of 139961664745768 at 12979
Get item 2 of 139961664745768 at 12982
Get item 0 of 139961664745768 at 12977
Get item 5 of 139961664745768 at 12982
Get item 3 of 139961664745768 at 12977
Get item 4 of 139961664745768 at 12979
Get item 6 of 139961664745768 at 12977
tensor([[0.]])
Regenerate 139961664745768 at 9793
Get item 7 of 139961664745768 at 12979
tensor([[0.]])
Regenerate 139961664745768 at 9793
Get item 8 of 139961664745768 at 12982
tensor([[0.]])
...

目前猜测,如果对 Dataset 的更改发生在子进程创建之前,更改就能生效,否则不生效。

解决方案

目前没有较好的解决方案,只能将 DataLoader.__init__num_workers 设置为 0 。这样可以在主进程中读取数据,避免创建子进程,但同时会减慢训练速度。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

LutingWang

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值