pytorch加载数据集(from torch.utils.data)

本文介绍了PyTorch中数据加载的三种方式:1) 自定义 Dataset 类,覆盖 __init__、__getitem__ 和 __len__ 方法;2) 使用 TensorDataset 将张量组合成数据集;3) 利用 DataLoader 进行批量加载和并行处理,支持洗牌和多进程。示例代码展示了如何在训练过程中应用这些工具。
摘要由CSDN通过智能技术生成

方式一:

**torch.utils.data.Dataset(*args, kwds)
数据集的抽象类,从键到数据样本的映射的所有数据集都应该对其进行子类化。 所有子类都应该覆盖 getitem(),支持获取给定键的数据样本。 子类还可以选择性地覆盖 len(),它有望通过许多 Sampler 实现和 DataLoader 的默认选项返回数据集的大小

from torch.utils import data 

class MyDataset(data.Dataset):
    # 构造函数带有默认参数
    def __init__(self):
    
    def __getitem__(self, index):
        return 
    def __len__(self):
        return 

train_data = MyDataset()
trainloader = data.DataLoader(train_data, shuffle=True, num_workers = 2, batch_size = batch_size)

方式二:

*torch.utils.data.TensorDataset(tensors)

def synthetic_data(w, b, num_examples):
    x = torch.normal(0, 1, (num_examples, len(w)))
    y = torch.matmul(x, w) + b 
    y += torch.normal(0, 0.001, y.shape)
    return x, y.reshape((-1, 1))

true_w = torch.tensor([2, -3.4])
true_b = 4.2

features, labels = synthetic_data(true_w, true_b, 1000)

dataset = data.TensorDataset(*data_arrays)
trainloader = data.DataLoader(dataset, batch_size, shuffle=is_train)

方式三:

*torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, , prefetch_factor=2, persistent_workers=False)

dataset = VOCDetection(root=dataset_root, transform = SSDAugmentation(cfg['min_dim'], MEANS))  
data_loader = torch.utils.data.DataLoader(dataset, batch_size, num_workers=0, shuffle=True, 
                                                                                       collate_fn=detection_collate, pin_memory=True)

def detection_collate(batch):
    targets = []
    imgs = []
    img_ids = []
    for sample in batch:
        imgs.append(sample[0])
        targets.append(torch.FloatTensor(sample[1]))
        img_ids.append(sample[2])
    return torch.stack(imgs, 0), targets, img_ids
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值