torch.utils.data.Dataset
class MyData(torch.utils.data.Dataset):
    """Dataset that serves ``(lb[index], dt[index])`` pairs.

    Both containers are indexed with the same ``index``; the dataset length
    is taken from ``dt`` alone, so ``lb`` is expected to be at least as long.

    Args:
        dt: Indexable, sized collection of samples. Each item is converted
            to a NumPy array when it is fetched.
        lb: Indexable collection holding the entry paired with each sample
            (presumably the labels -- confirm against the caller).
    """

    def __init__(self, dt, lb):
        self.dt = dt
        self.lb = lb

    def __len__(self):
        """Return the number of samples, i.e. ``len(dt)``."""
        return len(self.dt)

    def __getitem__(self, index):
        """Return ``(lb[index], np.array(dt[index]))``.

        Note the order: the ``lb`` entry comes first, the sample second.
        """
        # np.array() copies, so the caller never aliases the stored sample.
        return self.lb[index], np.array(self.dt[index])
MyData 继承自 torch.utils.data.Dataset，并重写了 __len__() 和 __getitem__() 两个方法。
torch.utils.data.TensorDataset:
class TensorDataset(Dataset):
    """Dataset wrapping tensors.

    Each sample will be retrieved by indexing tensors along the first
    dimension.

    Arguments:
        *tensors (Tensor): tensors that have the same size of the first
            dimension.

    Raises:
        AssertionError: if the tensors disagree on the size of their first
            dimension.
    """

    def __init__(self, *tensors):
        # Every tensor is compared against the first one, so all of them must
        # share the same number of samples along dimension 0.
        assert all(
            tensors[0].size(0) == tensor.size(0) for tensor in tensors
        ), "Size mismatch between tensors"
        self.tensors = tensors

    def __getitem__(self, index):
        """Return a tuple with the ``index``-th entry of every wrapped tensor."""
        return tuple(tensor[index] for tensor in self.tensors)

    def __len__(self):
        """Return the size of the first dimension of the first tensor."""
        return self.tensors[0].size(0)
从上面的代码可知，TensorDataset 继承了 Dataset，并同样实现了 __getitem__() 和 __len__() 方法。