[PyTorch] Removing Samples from a train_set or test_set by Index in a Custom Dataset Class

In deep learning work, we usually load raw or preprocessed data with PyTorch's built-in DataLoader plus a hand-written Dataset class. This pipeline is rigid, though: however many samples you hand to the Dataset, exactly that many are delivered by the DataLoader to the training stage. So a question arises: suppose we know some samples in the raw or preprocessed data are faulty, but we do not want to crudely clean the data (delete those samples) before it is passed to the Dataset class; instead we want our own Dataset class to filter those samples out automatically. How can this be done?

As an example, say the preprocessed data passed to the Dataset class has shape (1345, 256, 256, 3), i.e. 1345 images, and I know the 1239th image is bad and should not be used for training. I want the Dataset class to drop that sample automatically by its index (1238, since __getitem__ indices are zero-based) before the data is handed to training.

Credit for the idea goes to the post 《Pytorch自定义Dataset和DataLoader去除不存在和空的数据》 (a custom PyTorch Dataset and DataLoader that skip missing and empty data), but that author's code only handles a Dataset that returns two values, and it also requires writing a separate dataset_collate script, which makes it cumbersome and not very general.

The improved approach below takes only two steps.

1. Add a short snippet to your own Dataset class:

Suppose the initial Dataset class looks like this (simplified). Each sample returns several values, such as img, pose, heatmap, landmark; any number of return values is fine.

from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self):
        pass

    def load_data(self, path):
        pass

    def __len__(self):
        pass

    def __getitem__(self, index):
        ...  # load img, pose, heatmap, landmark for this sample
        return img, pose, heatmap, landmark

To drop a sample at a given index inside the Dataset class, simply set the first returned value to None for that index by adding these two lines:

        if index == 1238:   # index 1238 corresponds to the 1239th image
            img = None

Note that it does not have to be img that is set to None; img just happens to be the first value returned in this example, which is why img is the one flagged. The complete class is as follows:

from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self):
        pass

    def load_data(self, path):
        pass

    def __len__(self):
        pass

    def __getitem__(self, index):
        ...  # load img, pose, heatmap, landmark for this sample
        if index == 1238:   # index 1238 corresponds to the 1239th image
            img = None
        return img, pose, heatmap, landmark

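For reference, here is a self-contained toy version of the same pattern. The class name FilteredDataset, the random placeholder tensors, and the BAD_INDICES set are illustrative assumptions, not part of the original post; keeping the bad indices in a set generalizes the single if check to any number of faulty samples, and the flagged samples are then dropped by the collate step described next.

import torch
from torch.utils.data import Dataset

BAD_INDICES = {1238}  # zero-based indices of the faulty images; add more as needed

class FilteredDataset(Dataset):
    """Toy stand-in for the real dataset; the tensors are random placeholders."""

    def __init__(self, num_samples=1345):
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        img = torch.rand(3, 256, 256)   # placeholder image
        pose = torch.rand(3)            # placeholder pose vector
        heatmap = torch.rand(64, 64)    # placeholder heatmap
        landmark = torch.rand(68, 2)    # placeholder landmarks
        if index in BAD_INDICES:
            img = None                  # flag the sample; the collate step drops it
        return img, pose, heatmap, landmark
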
2. Modify the default_collate function in your Python environment's collate.py, located at your_python_envs_path/Lib/site-packages/torch/utils/data/_utils/collate.py

The Dataset change alone is not enough: with the stock default_collate() function you will get an error as soon as a flagged sample reaches collation, so two lines need to be added to it:

    if isinstance(batch, list):
        # drop every sample whose first returned value is None
        batch = [i for i in batch if i[0] is not None]

If you use VSCode, you can create a script, write import torch, type torch.utils.data.DataLoader, and Ctrl+click on DataLoader to jump to its source file; from that folder, open the _utils subfolder to find collate.py.
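
Alternatively, a two-line sketch can print the file's location directly. It imports the same private module, torch.utils.data._utils.collate, that this post edits, so the printed path is exactly the file to modify:

import torch.utils.data._utils.collate as collate_module
print(collate_module.__file__)  # e.g. ...\site-packages\torch\utils\data\_utils\collate.py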

The original default_collate() function:

def default_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size"""
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))

            return default_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(default_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))

The default_collate() function with the lines added:

def default_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size"""

    if isinstance(batch, list):
        # drop every sample whose first returned value is None
        batch = [i for i in batch if i[0] is not None]

    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))

            return default_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(default_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))
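
As a side note, if you would rather not edit the installed PyTorch source (the edit affects every project in that environment and is lost when PyTorch is upgraded), the same two lines can be wrapped in a standalone collate_fn and passed to the DataLoader. Below is a minimal sketch under the same assumption as above, namely that a flagged sample has None as its first value; filtering before delegating to the stock default_collate also keeps the filter out of default_collate's recursive calls:

from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate  # stock version, unmodified

def filtering_collate(batch):
    # Drop every sample whose first returned value is None, then collate the
    # survivors as usual. Caveat: if every sample in a batch is flagged, the
    # filtered list is empty and default_collate will raise.
    batch = [sample for sample in batch if sample[0] is not None]
    return default_collate(batch)

# FilteredDataset is the toy class sketched in step 1
loader = DataLoader(FilteredDataset(), batch_size=8, collate_fn=filtering_collate)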

After these two steps, you can remove samples from the training or test set by index inside your custom Dataset class.
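
As a quick sanity check, here is a sketch reusing the hypothetical FilteredDataset and filtering_collate from above (with the edited collate.py you can omit the collate_fn argument): the batch containing index 1238 should come out one sample short.

from torch.utils.data import DataLoader

loader = DataLoader(FilteredDataset(), batch_size=5, shuffle=False,
                    collate_fn=filtering_collate)
for i, (img, pose, heatmap, landmark) in enumerate(loader):
    if img.shape[0] != 5:
        print(f"batch {i}: {img.shape[0]} samples after filtering")
# Expected: "batch 247: 4 samples after filtering"
# (1345 samples, batch_size=5, and index 1238 falls in batch 247)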
