方式一:
torch.utils.data.Dataset(*args, **kwds)
数据集的抽象类,从键到数据样本的映射的所有数据集都应该对其进行子类化。 所有子类都应该覆盖 __getitem__(),支持获取给定键的数据样本。 子类还可以选择性地覆盖 __len__(),它有望通过许多 Sampler 实现和 DataLoader 的默认选项返回数据集的大小
from torch.utils import data
class MyDataset(data.Dataset):
    """Map-style dataset template (方式一).

    Subclasses of ``data.Dataset`` must implement ``__getitem__`` and
    should implement ``__len__``.  The original skeleton had empty method
    bodies (a syntax error); this version stores an optional sequence of
    samples so the class is runnable as-is while remaining a template.
    """

    def __init__(self, samples=None):
        """Constructor with a default argument (构造函数带有默认参数).

        samples: optional sequence of pre-loaded samples; defaults to an
        empty list.  ``None`` (not ``[]``) is the default to avoid the
        shared-mutable-default pitfall.
        """
        super().__init__()
        self.samples = [] if samples is None else samples

    def __getitem__(self, index):
        """Return the sample stored at ``index``."""
        return self.samples[index]

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.samples)
# Build the custom dataset and wrap it in a DataLoader (方式一 usage).
# Fix: batch_size was never defined in the original snippet — a NameError
# at runtime.  Any positive int works; 32 is a common default.
batch_size = 32
train_data = MyDataset()
trainloader = data.DataLoader(train_data, shuffle=True, num_workers=2, batch_size=batch_size)
方式二:
torch.utils.data.TensorDataset(*tensors)
def synthetic_data(w, b, num_examples):
    """Generate a synthetic linear-regression dataset: y = Xw + b + noise.

    w: weight vector (its length fixes the feature dimension).
    b: scalar bias.
    num_examples: number of rows to generate.

    Returns (features, labels) where features has shape
    (num_examples, len(w)) and labels has shape (num_examples, 1);
    the additive noise is Gaussian with std 0.001.
    """
    features = torch.normal(0, 1, (num_examples, len(w)))
    labels = torch.matmul(features, w) + b + torch.normal(0, 0.001, (num_examples,))
    return features, labels.reshape((-1, 1))
# 方式二 usage: wrap pre-built tensors with TensorDataset.
true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)
# Fix: data_arrays, batch_size and is_train were never defined in the
# original snippet (NameError at runtime).  Given the line above, the
# tensors to wrap can only be (features, labels).
data_arrays = (features, labels)
batch_size, is_train = 10, True
dataset = data.TensorDataset(*data_arrays)
trainloader = data.DataLoader(dataset, batch_size, shuffle=is_train)
方式三:
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=2, persistent_workers=False)
# 方式三 usage: DataLoader with a custom collate_fn.
# VOCDetection, SSDAugmentation, cfg, MEANS, dataset_root and batch_size are
# project-level names not visible in this snippet.
dataset = VOCDetection(root=dataset_root, transform = SSDAugmentation(cfg['min_dim'], MEANS))
# collate_fn is needed here because each sample's annotations cannot be
# stacked by the default collate (see detection_collate below, which keeps
# them as a Python list).
# NOTE(review): detection_collate is referenced here but defined later in
# this snippet — in a runnable top-to-bottom script the definition must
# come first; confirm the ordering in the real file.
data_loader = torch.utils.data.DataLoader(dataset, batch_size, num_workers=0, shuffle=True,
collate_fn=detection_collate, pin_memory=True)
def detection_collate(batch):
    """Collate a batch of detection samples.

    Each sample is indexed as (image, annotations, image_id).  Images are
    stacked along a new dim 0 into a single tensor; annotations are
    converted to FloatTensors but kept in a plain list (not stacked, so
    per-image ragged lengths are allowed); ids are returned as a list.

    Returns (stacked_images, targets_list, image_id_list).
    """
    stacked_imgs = torch.stack([sample[0] for sample in batch], 0)
    targets = [torch.FloatTensor(sample[1]) for sample in batch]
    img_ids = [sample[2] for sample in batch]
    return stacked_imgs, targets, img_ids