When working on deep learning tasks, we usually load raw or preprocessed data through PyTorch's built-in DataLoader combined with a hand-written Dataset class. This pipeline is rigid, though: however many samples you feed into the Dataset, exactly that many come out of the DataLoader during training. That raises a question: suppose we know some samples in the raw or preprocessed data are bad, but we don't want to crudely clean the data (delete those samples) before handing it to the Dataset class; instead, we want our own Dataset class to filter out those samples automatically. How can this be done?
As an example, say the preprocessed data passed to the Dataset class has shape (1345, 256, 256, 3), i.e. 1345 images, and I know the 1239th image is bad and should not be used for training. I want the Dataset class to drop that sample automatically, by its zero-based index 1238, before the data reaches training.
The idea is owed to the article "Pytorch自定义Dataset和DataLoader去除不存在和空的数据" (a custom Dataset and DataLoader that drop missing and empty data), but the original author's code only handles a Dataset that returns two values, and it also requires writing a separate dataset_collate script, which is complicated and not general.
The improved approach below takes just two steps:
1. Add a small piece of code to your own Dataset class.
Suppose the initial Dataset class looks like this (simplified). Each sample returns several values (img, pose, heatmap, landmark, and so on); the number of return values does not matter.
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self):
        pass

    def load_data(self, path):
        pass

    def __len__(self):
        pass

    def __getitem__(self, index):
        # ... load img, pose, heatmap, landmark for this index ...
        return img, pose, heatmap, landmark
To drop a sample by index inside the Dataset class, just set the first returned value to None by adding the following two lines:
if index == 1238:  # index 1238 corresponds to the 1239th image
    img = None
Note that it does not have to be img that gets set to None; img just happens to be the first value returned in this example, which is why we set img = None. Whatever your __getitem__ returns first is the field to set to None.
The complete class:
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self):
        pass

    def load_data(self, path):
        pass

    def __len__(self):
        pass

    def __getitem__(self, index):
        # ... load img, pose, heatmap, landmark for this index ...
        if index == 1238:  # index 1238 corresponds to the 1239th image
            img = None
        return img, pose, heatmap, landmark
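For concreteness, here is a minimal self-contained sketch of the same pattern on synthetic data; the ToyDataset name, the array shapes, and the BAD_INDICES set are illustrative assumptions, not part of the method itself:

import numpy as np
import torch
from torch.utils.data import Dataset

BAD_INDICES = {1238}  # hypothetical: zero-based indices of samples known to be bad

class ToyDataset(Dataset):
    def __init__(self, num_samples=1345):
        # synthetic stand-in for preprocessed (N, 256, 256, 3) image data
        self.images = np.zeros((num_samples, 256, 256, 3), dtype=np.float32)
        self.poses = np.zeros((num_samples, 3), dtype=np.float32)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        img = torch.from_numpy(self.images[index])
        pose = torch.from_numpy(self.poses[index])
        if index in BAD_INDICES:
            img = None  # mark the sample so the collate step below can filter it out
        return img, pose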
2. Modify the default_collate function in your Python environment's installed copy of your_python_envs_path/Lib/site-packages/torch/utils/data/_utils/collate.py.
The Dataset change alone is not enough: the original default_collate() function will raise an error on the None samples (torch.stack cannot stack None with tensors), so two lines must be added to it:
if isinstance(batch, list):
    batch = [i for i in batch if i[0] is not None]
If you use VS Code, you can create a script, write import torch, type torch.utils.data.DataLoader, and Ctrl+click on DataLoader to jump to the folder containing the DataLoader source; from there, open the _utils subfolder to find collate.py.
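Alternatively, you can print the file's location directly (a quick sketch; torch.utils.data._utils.collate is where recent PyTorch versions keep this function):

import torch.utils.data._utils.collate as collate_module

# prints the absolute path of the installed collate.py
print(collate_module.__file__)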
The original default_collate() function:
def default_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size"""
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))
            return default_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(default_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))
After adding the two lines, default_collate() looks like this:
def default_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size"""
    if isinstance(batch, list):
        # drop every sample whose first field is None
        batch = [i for i in batch if i[0] is not None]
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))
            return default_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(default_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))
After these two steps, your custom Dataset class can drop training or test samples by index. One caveat: the trick assumes each batch keeps at least one valid sample; if every sample in a batch were filtered out, default_collate() would fail on the empty list.
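As a side note, if you would rather not edit the installed library file (the change is lost whenever PyTorch is reinstalled or upgraded), the same filtering can be achieved by passing a custom collate_fn to the DataLoader. A minimal sketch, reusing the hypothetical ToyDataset from the earlier example:

from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate

def filtering_collate(batch):
    # drop samples whose first field is None, then defer to the stock collate
    batch = [sample for sample in batch if sample[0] is not None]
    return default_collate(batch)

loader = DataLoader(ToyDataset(), batch_size=32, collate_fn=filtering_collate)

This keeps the library source untouched and travels with your project code.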