import io
import os
from gluonnlp.data.dataset import Dataset, SimpleDataset
class Dataset(object):
    """Abstract dataset class. All datasets should have this interface.

    Subclasses need to override `__getitem__`, which returns the i-th
    element, and `__len__`, which returns the total number of elements.

    .. note:: An mxnet or numpy array can be directly used as a dataset.
    """

    def __getitem__(self, idx):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    # NOTE(review): the four methods below are empty placeholders in this copy
    # (their bodies are not implemented here); each one currently returns None.
    def filter(self, fn):
        pass

    def shard(self, num_shards, index):
        pass

    def take(self, count):
        pass

    def sample(self, sampler):
        pass

    def transform(self, fn, lazy=True):
        """Returns a new dataset with each sample transformed by the
        transformer function `fn`.

        Parameters
        ----------
        fn : callable
            A transformer function that takes a sample as input and
            returns the transformed sample. If a sample is a tuple, it is
            unpacked into positional arguments, i.e. ``fn(*sample)``.
        lazy : bool, default True
            Whether to transform samples lazily.
            If False, transforms all samples at once. Otherwise,
            transforms each sample on demand. Note that if `fn`
            is stochastic, you must set lazy to True or you will
            get the same result on all epochs.

        Returns
        -------
        Dataset
            The transformed dataset.
        """
        # `self` is the source dataset being wrapped; no sample is touched yet.
        trans = _LazyTransformDataset(self, fn)
        if lazy:
            return trans
        # Materialize by explicit index over `__len__` rather than plain
        # iteration: _LazyTransformDataset has no `__iter__`, so iterating it
        # would use the legacy `__getitem__` protocol, which only stops if the
        # wrapped data happens to raise IndexError past its last element.
        return SimpleDataset([trans[i] for i in range(len(trans))])

    def transform_first(self, fn, lazy=True):
        """Returns a new dataset with the first element of each sample
        transformed by the transformer function `fn`.

        This is useful, for example, when you only want to transform data
        while keeping label as is. Samples that are not tuples are passed to
        `fn` whole.

        Parameters
        ----------
        fn : callable
            A transformer function that takes the first element of a sample
            as input and returns the transformed element.
        lazy : bool, default True
            If False, transforms all samples at once. Otherwise,
            transforms each sample on demand. Note that if `fn`
            is stochastic, you must set lazy to True or you will
            get the same result on all epochs.

        Returns
        -------
        Dataset
            The transformed dataset.
        """
        # A picklable callable object is used instead of a nested function.
        return self.transform(_TransformFirstClosure(fn), lazy)
class _LazyTransformDataset(Dataset):
    """Dataset view that applies `fn` to an element only when it is accessed."""

    def __init__(self, data, fn):
        # `data` may be a plain list or another dataset object (anything
        # supporting len() and integer indexing).
        self._data = data
        self._fn = fn

    def __len__(self):
        return len(self._data)

    def __getitem__(self, idx):
        sample = self._data[idx]
        # Tuple samples are splatted into positional arguments of `fn`;
        # anything else is passed as the single argument.
        call_args = sample if isinstance(sample, tuple) else (sample,)
        return self._fn(*call_args)
class _TransformFirstClosure(object):
"""Use callable object instead of nested function, it can be pickled."""
def __init__(self, fn):
self._fn = fn
def __call__(self, x, *args):
if args:
return (self._fn(x),) + args
return self._fn(x)
if __name__ == '__main__':
    # NOTE(review): `SimpleDataset` is imported from gluonnlp, and the printed
    # repr of the wrapped dataset is mxnet.gluon.data.dataset.SimpleDataset, so
    # these `transform` / `transform_first` calls run the library's
    # implementation rather than the `Dataset` class defined in this file.
    pairs = SimpleDataset([(1, 1), (2, 2)])

    # Lazy: each (x, y) tuple is unpacked into fn(x, y) only when accessed.
    summed = pairs.transform(lambda x, y: x + y, lazy=True)
    print(summed._data)        # the wrapped source dataset
    print(summed._data._data)  # its raw list, still untransformed because lazy=True
    print(summed[0])           # indexing is what actually applies fn
    print(list(summed))

    # Eager: only the first field of each sample is transformed, all at once.
    bumped = pairs.transform_first(lambda x: x + 1, lazy=False)
    print(list(bumped))
结果 (output of running the demo above):
<mxnet.gluon.data.dataset.SimpleDataset object at 0x00000205FE0D8BC8>
[(1, 1), (2, 2)]
2
[2, 4]
[(2, 1), (3, 2)]