The return value of __getitem__ in PyTorch

To use a custom dataset in PyTorch, you define a subclass of Dataset and override the parent class's __len__ and __getitem__ methods.
For example, __getitem__ usually returns an ordinary data pair x, y, but it can also return several x's and y's: in few-shot learning, where each sample is a query/support pair, that means two x's and two y's.

class MyDataset(Dataset):
    '''
    define the data here (e.g. store self.x_data in __init__)
    '''
    def __len__(self):
        return len(self.x_data)

    def __getitem__(self, idx):
        '''
        per-sample processing goes here
        '''
        return x, y
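A minimal runnable sketch of this pattern (the __init__ signature and tensor shapes here are illustrative, not from the original): the DataLoader stacks the per-sample tuples into batched tensors.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    def __init__(self, x_data, y_data):
        # store the raw tensors; real code might load files instead
        self.x_data = x_data
        self.y_data = y_data

    def __len__(self):
        return len(self.x_data)

    def __getitem__(self, idx):
        # return one (x, y) sample
        return self.x_data[idx], self.y_data[idx]

ds = MyDataset(torch.randn(8, 3), torch.arange(8))
loader = DataLoader(ds, batch_size=4)
xb, yb = next(iter(loader))
print(xb.shape, yb.shape)  # torch.Size([4, 3]) torch.Size([4])
```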

However, __getitem__ can also return a dictionary, for example:

    def __getitem__(self, idx):
        '''
        ... (omitted)
        '''
        batch = {'query_img': query_img,
                 'query_mask': query_mask,
                 'query_name': query_name,
                 'query_ignore_idx': query_ignore_idx,

                 'support_imgs': support_imgs,
                 'support_masks': support_masks,
                 'support_names': support_names,
                 'support_ignore_idxs': support_ignore_idxs,

                 'class_id': torch.tensor(class_sample)}

        return batch
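A simplified, runnable sketch of a dict-returning Dataset (the fields below are stand-in random tensors, not real query/support images): the DataLoader's default collation turns the per-sample dicts into a single dict of batched tensors, keyed the same way.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class DictDataset(Dataset):
    def __len__(self):
        return 6

    def __getitem__(self, idx):
        # each sample is a dict; values are stand-ins for real data
        return {'query_img': torch.randn(3, 8, 8),
                'class_id': torch.tensor(idx % 2)}

loader = DataLoader(DictDataset(), batch_size=3)
batch = next(iter(loader))
print(batch['query_img'].shape)  # torch.Size([3, 3, 8, 8])
print(batch['class_id'].shape)   # torch.Size([3])
```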

Let me explain why a dict can be returned.
Normally, after defining the Dataset and instantiating it, we create a DataLoader and pass the dataset to it. The DataLoader's job is to combine the data obtained from multiple __getitem__ calls and return one batch. The DataLoader constructor takes a parameter called collate_fn, which defines how the individual samples are assembled into a batch:

# parameter description, from the DataLoader docstring
        collate_fn (callable, optional): merges a list of samples to form a
            mini-batch of Tensor(s).  Used when using batched loading from a
            map-style dataset.
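Any callable can be passed as collate_fn. As a quick illustration (this identity collate is a hypothetical example, not from the original post), the following keeps the samples as a plain Python list instead of stacking them into tensors:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.arange(6).float())
# identity collate: return the list of samples untouched
loader = DataLoader(ds, batch_size=3, collate_fn=lambda samples: samples)
batch = next(iter(loader))
print(type(batch), len(batch))  # <class 'list'> 3
```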

Now let's look at how the default collate_fn (default_collate) is defined:

def default_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size"""

    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))

            return default_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(default_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))

The branch that handles dicts is this one:

 elif isinstance(elem, collections.abc.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in elem}

Here elem is the first element of the batch. For each key, the dict comprehension gathers the values under that key from every sample in the batch into one list, then calls default_collate recursively on that list.
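This branch can be exercised directly. In recent PyTorch versions default_collate is exposed from torch.utils.data (in older versions it lives in an internal module, so the import below is an assumption about your PyTorch version):

```python
import torch
from torch.utils.data import default_collate

# two dict samples with matching keys
samples = [{'x': torch.tensor([1.0]), 'y': 0},
           {'x': torch.tensor([2.0]), 'y': 1}]
batch = default_collate(samples)
print(batch['x'])  # tensor([[1.], [2.]])  -- tensors are stacked
print(batch['y'])  # tensor([0, 1])        -- ints become one tensor
```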

Finally, let's look at where collate_fn is called:

class _MapDatasetFetcher(_BaseDatasetFetcher):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)

    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            data = [self.dataset[idx] for idx in possibly_batched_index]
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)
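A quick way to see the auto_collation switch in action: passing batch_size=None disables automatic batching, so fetch() takes the else branch and each iteration yields a single sample rather than a collated batch (a small sketch, with a hypothetical dataset):

```python
import torch
from torch.utils.data import Dataset, DataLoader

class Squares(Dataset):
    def __len__(self):
        return 4

    def __getitem__(self, idx):
        return torch.tensor(idx * idx)

# batch_size=None: no automatic batching, one sample per iteration
loader = DataLoader(Squares(), batch_size=None)
first = next(iter(loader))
print(first)  # tensor(0)
```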