Pytorch中iter(dataloader)的使用

本文介绍了PyTorch中的DataLoader如何作为可迭代对象工作,通过iter()和enumerate()访问数据集。示例展示了如何加载MNIST数据集,并以批次方式处理图像和标签。在使用enumerate()时,注意imgs和labels的顺序,它们分别代表了图像数据和对应的标签值。
摘要由CSDN通过智能技术生成

dataloader本质上是一个可迭代对象,可以使用iter()进行访问,采用iter(dataloader)返回的是一个迭代器,然后可以使用next()访问。
也可以使用enumerate(dataloader)的形式访问。
下面举例说明:

transformation = transforms.Compose([
    transforms.ToTensor()
])

train_ds = datasets.MNIST("./data", train=True, transform=transformation, download=True)

test_ds = datasets.MNIST("./data", train=False, transform=transformation, download=True)

train_dl = DataLoader(train_ds, batch_size=64, shuffle=True)
test_dl = DataLoader(test_ds, batch_size=256)
#imgs, labels = next(iter(train_dl))

for labels, imgs in enumerate(train_dl): #如果imgs在前,labels在后,那么imgs将是标签形式,labels才是图片转化0~1之间的值。
    print("imgs:\t", imgs)
    print("labels:\t", labels)
labels:	 3
imgs:	 [tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        ...,


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]]]), tensor([5, 1, 7, 6, 7, 9, 6, 4, 0, 4, 0, 4, 4, 4, 2, 7, 5, 2, 9, 2, 1, 9, 1, 8,
        2, 6, 8, 0, 1, 6, 1, 0, 3, 6, 6, 2, 5, 1, 3, 4, 4, 1, 8, 4, 8, 1, 2, 5,
        2, 0, 1, 3, 6, 6, 0, 1, 7, 6, 0, 8, 3, 7, 1, 6])]

iter(dataloader)访问时,imgs在前,labels在后,分别表示:图像转换0~1之间的值,labels为标签值。并且imgs和labels是按批次进行输入的。

transformation = transforms.Compose([
    transforms.ToTensor()
])

train_ds = datasets.MNIST("./data", train=True, transform=transformation, download=True)

test_ds = datasets.MNIST("./data", train=False, transform=transformation, download=True)

train_dl = DataLoader(train_ds, batch_size=64, shuffle=True)
test_dl = DataLoader(test_ds, batch_size=256)
#imgs, labels = next(iter(train_dl))
"""
for labels, imgs in enumerate(train_dl):
    print("imgs:\t", imgs)
    print("labels:\t", labels)
"""
for imgs, labels in iter(train_dl):
    print("imgs:\t", imgs)
    print("label:\t", labels)
imgs:	 tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        ...,


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]]])
label:	 tensor([6, 1, 0, 8, 6, 7, 8, 1, 3, 4, 8, 5, 8, 9, 7, 2, 9, 3, 0, 6, 1, 1, 4, 6,
        0, 6, 7, 9, 3, 7, 1, 3, 5, 2, 7, 1, 1, 0, 3, 0, 1, 0, 8, 7, 5, 1, 5, 6,
        3, 3, 1, 3, 8, 6, 8, 7, 6, 3, 8, 3, 1, 0, 2, 7])
  • 28
    点赞
  • 88
    收藏
    觉得还不错? 一键收藏
  • 10
    评论
### 回答1: 好的,下面给出一个基于 PyTorch 的元学习 DataLoader 的示例代码: ```python import torch from torch.utils.data import Dataset, DataLoader class MetaDataset(Dataset): def __init__(self, data): self.data = data def __len__(self): return len(self.data) def __getitem__(self, idx): return self.data[idx] class MetaDataloader(DataLoader): def __init__(self, dataset, n_way, k_shot, q_query, batch_size=1, shuffle=True): self.n_way = n_way self.k_shot = k_shot self.q_query = q_query self.dataset = dataset self.classes = list(set([data[1] for data in dataset])) self.class_dict = {c: [data for data in dataset if data[1] == c] for c in self.classes} self.batch_size = batch_size self.shuffle = shuffle super().__init__(dataset, batch_size=batch_size, shuffle=shuffle) def __iter__(self): for _ in range(self.batch_size): support_set = [] query_set = [] # Sample n_way classes classes = torch.randperm(len(self.classes))[:self.n_way] for c in classes: # Sample k_shot + q_query examples per class samples = torch.randperm(len(self.class_dict[self.classes[c]]))[:self.k_shot + self.q_query] support_set.extend([(self.class_dict[self.classes[c]][i][0], c) for i in samples[:self.k_shot]]) query_set.extend([(self.class_dict[self.classes[c]][i][0], c) for i in samples[self.k_shot:]]) if self.shuffle: # Shuffle support set and query set support_set = torch.stack(support_set).split(self.k_shot) query_set = torch.stack(query_set).split(self.q_query) zipped = list(zip(support_set, query_set)) torch.shuffle(zipped) support_set, query_set = zip(*zipped) support_set = torch.cat(support_set) query_set = torch.cat(query_set) else: support_set = torch.stack(support_set).split(self.k_shot) query_set = torch.stack(query_set).split(self.q_query) yield support_set, query_set ``` 这个 MetaDataloader 类继承自 PyTorchDataLoader 类,并且添加了元学习的逻辑。它接受一个数据集、一个 n_way 参数(表示每个 batch 包含几个类别)、一个 k_shot 参数(表示每个类别包含几个样本作为支持集)、一个 q_query 参数(表示每个类别包含几个样本作为查询集)、以及其他 DataLoader 支持的参数。 在每个 epoch ,MetaDataloader 会随机选择 n_way 个类别,并从每个类别随机选择 k_shot + q_query 个样本。它会把前 k_shot 个样本作为支持集,后 q_query 个样本作为查询集。然后,它会将这些支持集和查询集打包成一个 tuple,返回给调用者。 如果 shuffle 参数为 True,MetaDataloader 会将每个支持集和查询集打包成一个 tuple,并随机打乱它们的顺序。如果 shuffle 参数为 False,MetaDataloader 会保持它们的顺序不变。 你可以按照以下方式使用这个 MetaDataloader 类: ```python # 创建一个元数据集 data = [(torch.randn(3, 224, 224), i // 5) for i in range(100)] # 创建一个 MetaDataloader meta_dataloader = MetaDataloader(MetaDataset(data), n_way=5, k_shot=1, q_query=1, batch_size=2) # 使用 MetaDataloader 进行训练 for support_set, query_set in meta_dataloader: # 在这里进行训练 pass ``` 这个示例代码,我们创建了一个元数据集,包含 100 个样本,每个样本由一个大小为 (3, 224, 224) 的张量和一个从 0 到 4 的标签组成。然后,我们创建了一个 MetaDataloader,它每个 batch 包含 5 个类别,每个类别包含 1 个支持集和 1 个查询集,每个 batch 包含 2 个这样的元素。最后,我们使用这个 MetaDataloader 进行训练。在训练过程,我们会得到一个支持集和一个查询集的 tuple,可以在其进行模型的训练和推理。 ### 回答2: 元学习是一种能够快速学习和适应新任务的机器学习算法,其核心思想是通过在多个任务上进行训练,使模型能够从过去的经验提取出通用的知识,进而在面对新任务时能够更快速地适应和学习。 基于pytorch构建元学习dataloader需要以下步骤: 1. 创建一个自定义的数据集类(Dataset):该类需要继承自torch.utils.data.Dataset,并实现__len__和__getitem__方法。在__getitem__方法,根据输入的index获取一个任务样本(例如样本的输入和标签),并将其返回。 2. 创建一个元学习数据集类(MetaDataset):该类也需要继承自torch.utils.data.Dataset,并实现__len__和__getitem__方法。在__getitem__方法,根据输入的index获取一个元任务样本,然后根据该样本的描述信息(例如任务类别)加载对应的任务数据集,并将其返回。 3. 创建一个元学习dataloader类(MetaDataloader):该类需要实现能够高效加载和处理元任务数据集的功能。可以使用torch.utils.data.DataLoader来处理任务数据集的加载,根据需要设置batch size、shuffle等参数。 4. 基于以上的数据集和dataloader类,可以进行元学习模型的训练和测试。在训练过程,首先从元学习dataloader获取一个元任务样本,然后利用该样本的任务数据集进行模型的训练。在测试过程,也可以通过元学习dataloader提供的接口来获取测试数据集。 总之,基于pytorch的元学习dataloader的实现需要创建自定义的数据集类、元学习数据集类和元学习dataloader类,并在训练和测试过程使用它们来读取和处理元任务数据。通过这种方式,可以方便地加载和使用元任务数据集,从而实现元学习算法的训练和测试。
评论 10
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值