ConcatDataset
- 如果需要同时使用两个(或多个)Dataset,可以用ConcatDataset把它们拼接成一个Dataset,再通过统一的索引读取其中的数据。
- 例如在半监督任务中,有标注数据(labeled data)和无标注数据(unlabeled data)需要分别构建各自的Dataset,但在训练时要合并在一起使用。
class ConcatDataset(Dataset[T_co]):
    r"""Dataset as a concatenation of multiple datasets.

    This class is useful to assemble different existing datasets, e.g. to
    read a labeled and an unlabeled dataset through one shared index space.

    Args:
        datasets (iterable): Datasets to be concatenated. Any iterable is
            accepted (including one-shot iterators such as generators); it
            is materialized into a list on construction.
    """

    datasets: List[Dataset[T_co]]
    # Running totals of the dataset lengths:
    # [len(D1), len(D1) + len(D2), ...]; the last entry is the total length.
    cumulative_sizes: List[int]

    @staticmethod
    def cumsum(sequence):
        """Return the running totals of ``len(e)`` for every ``e`` in *sequence*.

        The result is a list ``[len(D1), len(D1) + len(D2), ...]``.
        """
        totals, running = [], 0
        for element in sequence:
            running += len(element)
            totals.append(running)
        return totals

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        """Store the datasets and precompute their cumulative sizes.

        Raises:
            AssertionError: If *datasets* is empty or contains an
                ``IterableDataset``.
        """
        super().__init__()
        # Materialize first: `datasets` is only promised to be iterable, and
        # calling len() directly on a generator/iterator raises TypeError.
        self.datasets = list(datasets)
        assert len(self.datasets) > 0, 'datasets should not be an empty iterable'
        for d in self.datasets:
            assert not isinstance(d, IterableDataset), "ConcatDataset does not support IterableDataset"
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        """Return the total number of samples across all datasets."""
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        """Return the sample at global index *idx*.

        Every lookup first determines which underlying dataset the index
        falls into, then fetches the sample from that dataset.

        Raises:
            ValueError: If *idx* is negative and its absolute value exceeds
                the total length.
            IndexError: If *idx* is at or beyond the total length (raised by
                the lookup in ``self.datasets``).
        """
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = len(self) + idx
        # Position of the dataset that contains the global index.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            # Subtract the combined length of all preceding datasets so the
            # index restarts from 0 inside the selected dataset.
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]

    @property
    def cummulative_sizes(self):
        """Deprecated misspelled alias of ``cumulative_sizes``; warns on access."""
        warnings.warn("cummulative_sizes attribute is renamed to "
                      "cumulative_sizes", DeprecationWarning, stacklevel=2)
        return self.cumulative_sizes
使用时,一般像下面这样继承ConcatDataset,定义一个最外层的Dataset即可:
class SemiDataset(ConcatDataset):
    """Concatenation of a supervised and an unsupervised dataset.

    Both parts are created with ``build_dataset`` and stored, in order, as
    ``datasets[0]`` (supervised) and ``datasets[1]`` (unsupervised).

    Args:
        sup: Argument handed to ``build_dataset`` for the supervised part.
        unsup: Argument handed to ``build_dataset`` for the unsupervised part.
        **kwargs: Accepted but not used by this constructor.
    """

    def __init__(self, sup: dict, unsup: dict, **kwargs):
        # Build in a fixed order: the properties below rely on positions 0/1.
        parts = [build_dataset(cfg) for cfg in (sup, unsup)]
        super().__init__(parts)

    @property
    def sup(self):
        """The supervised dataset (first concatenated part)."""
        supervised = self.datasets[0]
        return supervised

    @property
    def unsup(self):
        """The unsupervised dataset (second concatenated part)."""
        unsupervised = self.datasets[1]
        return unsupervised