思路:自定义一个 sampler,把同一组采样索引传给不同的 DataLoader,这样各个 DataLoader 就会按相同的顺序(相同的下标)取出各自数据集中的样本。
代码:
import torch
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
class MyDataset_v1(Dataset):
    """Toy map-style dataset holding the integers 1 through 4."""

    def __init__(self) -> None:
        # A fresh list per instance, so instances never share state.
        self.data = [1, 2, 3, 4]

    def __len__(self) -> int:
        """Return the number of stored samples."""
        return len(self.data)

    def __getitem__(self, item: int) -> int:
        """Return the sample stored at position ``item``."""
        return self.data[item]
class MyDataset_v2(Dataset):
    """Toy map-style dataset holding the floats 1.1, 2.2, 3.3 and 4.4."""

    def __init__(self) -> None:
        # A fresh list per instance, so instances never share state.
        self.data = [1.1, 2.2, 3.3, 4.4]

    def __len__(self) -> int:
        """Return the number of stored samples."""
        return len(self.data)

    def __getitem__(self, item: int) -> float:
        """Return the sample stored at position ``item``."""
        return self.data[item]
class SamplerDef(object):
    """Sampler that replays a caller-supplied collection of indices.

    Iterating the sampler yields the entries of ``indices`` in their
    stored order. Handing the same ``indices`` to samplers used by
    several data loaders is meant to make those loaders visit their
    datasets in the same order.

    Args:
        data_source: The dataset the indices refer to. It is stored on
            the instance and only consulted as a length fallback.
        indices: Iterable of positions to visit. It is stored by
            reference (not copied), so later in-place changes made by the
            caller are visible to the sampler.
    """

    def __init__(self, data_source, indices):
        self.data_source = data_source
        self.indices = indices

    def __iter__(self):
        """Return a fresh iterator over the stored indices."""
        return iter(self.indices)

    def __len__(self):
        """Return how many indices one pass over the sampler yields.

        The length must describe what ``__iter__`` produces, which is
        ``indices`` and not ``data_source``; the two differ whenever
        ``indices`` is a subset or an oversampled list.
        """
        try:
            return len(self.indices)
        except TypeError:
            # ``indices`` has no length (e.g. a generator); keep the
            # previous behaviour of reporting the dataset size.
            return len(self.data_source)
if __name__ == "__main__":
myDataset1 = MyDataset_v1()