Dataset的简单构建

我们自定义的Dataset类必须要实现:

class dataset(Dataset):
    def __init__(self, corpus_path, sentence_max_length):
        pass

    def __getitem__(self, index):
        pass

    def __len__(self):
        pass

可以看到自定义的dataset必须要有len()方法,和下标索引方法

所以我们并不是构建一个dataloader,必须要自定义一个Dataset类。

只要我们传入的数据集支持len,下标访问,且每次访问返回一个数据和标签就可以了,例如列表和数组。

Example:

dataset = [np.numpy(data),int(label)]
len(dataset)
data,label = dataset[index]

上面这个例子没有构建dataset类但是也满足Dataset类的所有特征,所以也可以直接加载到DataLoader中,但是可能在适配自己的模型时需要自己实现一下DataLoader中的collate_fn()函数,来对数据进行正确拼接。

至于是否需要进行自定义collate_fn ,主要看我们输入的数据是否为tensor格式, 如果内部元素是tensor格式,那么就不需要自己重新实现collate_fn, 如果内部元素不是tensor格式,就需要自己重新实现该函数。

关于如何构建collate_fn从0构建一个collate_fn函数

至于为什么,可以看下面的default_collate()源码。如果内部元素是tensor格式的话,则可以进入第一个if分枝语句,通过torch.stack()进行构建batch。 如果内部元素不是tensor格式,例如为元组或者,列表形式。那么就要进入最后一个else分枝语句块,通过执行transposed = zip(*batch),这产生的结果可能不是我们所期望的。

DataLoader(dataset: Dataset[T_co], batch_size: Optional[int] = 1,shuffle: bool = False, 
           sampler: Optional[Sampler[int]] = None,batch_sampler: Optional[Sampler[Sequence[int]]] = None,
           num_workers: int = 0, collate_fn: Optional[_collate_fn_t] = None,
           pin_memory: bool = False, drop_last: bool = False,
           timeout: float = 0, worker_init_fn: Optional[_worker_init_fn_t] = None,
           multiprocessing_context=None, generator=None,
           *, prefetch_factor: int = 2,
           persistent_workers: bool = False)
def default_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size"""

    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
           return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))

            return default_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(default_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值