DataLoader的使用

官方文档说明

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)

其中最重要的当属 dataset 一项,pytorch 支持两种类型的 dataset

  • map-style datasets
  • iterable-style datasets

对于 map-style.dataset 类型,它需要 _getitem__() and __len__() 这两个函数

下面我们以原始数据 ‘abcdefg’ 为例进行说明。注意两个函数的编写,以及 torch.utils.data.DataLoader() 的参数变化

将’abcdefg’顺序遍历

import torch
import torch.utils.data 
class ExampleDataset(torch.utils.data.Dataset):
  def __init__(self):
    self.data = "abcdefg"
  
  def __getitem__(self,idx): # if the index is idx, what will be the data?
    return self.data[idx]
  
  def __len__(self): # What is the length of the dataset
    return len(self.data)

dataset1 = ExampleDataset() # create the dataset
dataloader = torch.utils.data.DataLoader(dataset = dataset1,shuffle = True,batch_size = 1)
for datapoint in dataloader:
  print(datapoint)
  
  
------output----------
['a']
['b']
['c']
['d']
['e']
['f']
['g']

shuffle=True,进行打乱,随机取出

import torch
import torch.utils.data 
class ExampleDataset(torch.utils.data.Dataset):
  def __init__(self):
    self.data = "abcdefg"
  
  def __getitem__(self,idx): # if the index is idx, what will be the data?
    return self.data[idx]
  
  def __len__(self): # What is the length of the dataset
    return len(self.data)

dataset1 = ExampleDataset() # create the dataset
dataloader = torch.utils.data.DataLoader(dataset = dataset1,shuffle = True,batch_size = 1)
for datapoint in dataloader:
  print(datapoint)
  
------output----------------
['f']
['a']
['d']
['e']
['c']
['g']
['b']

改变batch_size

import torch
import torch.utils.data 
class ExampleDataset(torch.utils.data.Dataset):
  def __init__(self):
    self.data = "abcdefg"
  
  def __getitem__(self,idx): # if the index is idx, what will be the data?
    return self.data[idx]
  
  def __len__(self): # What is the length of the dataset
    return len(self.data)

dataset1 = ExampleDataset() # create the dataset
dataloader = torch.utils.data.DataLoader(dataset = dataset1,shuffle = True,batch_size = 2)
for datapoint in dataloader:
  print(datapoint)


-----------output-------------
['d', 'c']
['f', 'b']
['e', 'a']
['g']

改写_getitem__()and __len__()以达到自己想要的结果

import torch
import torch.utils.data 
class ExampleDataset(torch.utils.data.Dataset):
  def __init__(self):
    self.data = "abcdefg"
  
  def __getitem__(self,idx): # if the index is idx, what will be the data?
    return self.data[idx], self.data[idx].upper()
  
  def __len__(self): # What is the length of the dataset
    return len(self.data)

dataset1 = ExampleDataset() # create the dataset
dataloader = torch.utils.data.DataLoader(dataset = dataset1,shuffle = False,batch_size = 2)
for datapoint in dataloader:
  print(datapoint)


-----------output-----------
[('a', 'b'), ('A', 'B')]
[('c', 'd'), ('C', 'D')]
[('e', 'f'), ('E', 'F')]
[('g',), ('G',)]
import torch.utils.data 
class ExampleDataset(torch.utils.data.Dataset):
  def __init__(self):
    self.data = "abcdefg"
  
  def __getitem__(self,idx): # if the index is idx, what will be the data?
    if idx >= len(self.data): # if the index >= 26, return upper case letter
      return self.data[idx%7].upper()
    else: # if the index < 26, return lower case, return lower case letter
      return self.data[idx]
  
  def __len__(self): # What is the length of the dataset
    return 2 * len(self.data) # The length is now twice as large

dataset1 = ExampleDataset() # create the dataset
dataloader = torch.utils.data.DataLoader(dataset = dataset1,shuffle = False,batch_size = 2)
for datapoint in dataloader:
  print(datapoint)

-----------output------------
['a', 'b']
['c', 'd']
['e', 'f']
['g', 'A']
['B', 'C']
['D', 'E']
['F', 'G']

带有 transform 的读取本地图片的 Dataset

DATA_DIR = 'data/CIFAR-10'
DATABASE_FILE = 'database_img.txt'
DATABASE_LABEL = 'database_label.txt'

# 一般不同网络架构可能对输入的图片数据有格式要求,可以在此处做处理
# 当然是用 transforms 的操作除了满足网络输入的需求,同样还可以用作数据加强
 transformations = transforms.Compose([
        transforms.Scale(256),
        transforms.CenterCrop(224),  
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    
 dset_database = ExampleDataset(
      DATA_DIR, DATABASE_FILE, DATABASE_LABEL, transformations)
class ExampleDataset(Dataset):
    def __init__(self, data_path, img_filename, label_filename, transform=None):
        self.img_path = data_path
        self.transform = transform
        # reading img file from file
        img_filepath = os.path.join(data_path, img_filename)
        fp = open(img_filepath, 'r')
        self.img_filename = [x.strip() for x in fp]
        fp.close()
        label_filepath = os.path.join(data_path, label_filename)
        fp_label = open(label_filepath, 'r')
        labels = [int(x.strip()) for x in fp_label]
        fp_label.close()
        self.label = labels

    def __getitem__(self, index):
        img = Image.open(os.path.join(self.img_path, self.img_filename[index]))
        img = img.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        label = torch.LongTensor([self.label[index]])
        return img, label, index
    def __len__(self):
        return len(self.img_filename)

参考拓展:https://www.daimajiaoliu.com/daima/4ede05ecd1003fc

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值