【pytorch】DataLoader 和 Dataset 的使用

加载顺序

pytorch中加载数据的顺序是:
①创建一个dataset对象
②创建一个dataloader对象
③循环调用dataloader对象,获取data,label数据拿到模型中去训练

Dataset

你需要自己定义一个class继承父类Dataset,其中至少需要重写以下3个函数:
①__init__:传入数据,或者加载数据
②__len__:返回这个数据集一共有多少个item
③__getitem__: 返回一条训练数据,并将其转换成tensor

示例代码:

class MyData(Dataset):

  def __init__(self, x_patches, y_patches, transform = None):
    self.y_patches = y_patches
    self.x_patches = x_patches
    self.transform = transform

  def __len__(self):
    return len(self.y_patches)

  def __getitem__(self, idx):
    y_image = self.y_patches[idx]
    x_image = self.x_patches[idx]

    y_image = np.asarray(y_image)
    x_image = np.asarray(x_image)

    y_image = Image.fromarray(y_image.astype(np.uint8))
    x_image = Image.fromarray(x_image.astype(np.uint8))

    if self.transform:
       y_image = self.transform(y_image)
       x_image = self.transform(x_image)

    return x_image, y_image

DataLoader

参数:
dataset:传入的数据
shuffle = True:是否打乱数据
collate_fn:这个参数可以自己操作每个batch的数据 参考:Pytorch中DataLoader的使用_kahuifu的博客-CSDN博客

示例代码:

dataset = MyData(x_patches, y_patches, transform=transforms.Compose(
            [transforms.ToTensor(), 
             transforms.Normalize([0.5], [0.5])]))

bs = 16
data_loader = DataLoader(dataset, batch_size=bs, shuffle=True)
num_batches = len(data_loader)

调用DateLoader

最后循环调用dataloader ,拿到数据放入模型进行训练

for n_batch, (x_batch, y_batch) in enumerate(data_loader):

    x_data = x_batch.float().cuda()
    y_data = y_batch.float().cuda()

  • 0
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值