Building a Data Pipeline in PyTorch

There are two ways to build an image data pipeline in PyTorch:

  1. Use datasets.ImageFolder from torchvision to read the images, then load them in parallel with DataLoader;
  2. Subclass torch.utils.data.Dataset to implement your own reading logic, then load in parallel with DataLoader.
    The second method is the general approach for custom datasets: it works for image datasets as well as text datasets.

The first method looks like this:


import torch
from torch import nn
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms,datasets

transform_train = transforms.Compose(
    [transforms.ToTensor()])
transform_valid = transforms.Compose(
    [transforms.ToTensor()])


# target_transform turns the integer class index into a float tensor of shape [1]
ds_train = datasets.ImageFolder("/media/yfh/hd/resources/databases/cifar2/train/",
            transform = transform_train, target_transform = lambda t: torch.tensor([t]).float())
ds_valid = datasets.ImageFolder("/media/yfh/hd/resources/databases/cifar2/test/",
            transform = transform_valid, target_transform = lambda t: torch.tensor([t]).float())

print(ds_train.class_to_idx)  #{'0_airplane': 0, '1_automobile': 1}
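ImageFolder assigns labels by sorting the immediate subdirectory names under the root, so the mapping above comes directly from the on-disk layout, which presumably looks like this (inferred from the paths and class_to_idx above; file names are placeholders):

cifar2/
├── train/
│   ├── 0_airplane/      <- label 0
│   └── 1_automobile/    <- label 1
└── test/
    ├── 0_airplane/
    └── 1_automobile/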


%matplotlib inline
%config InlineBackend.figure_format = 'svg'

# Inspect a few samples
from matplotlib import pyplot as plt

plt.figure(figsize=(8,8))
for i in range(9):
    img,label = ds_train[i]
    img = img.permute(1,2,0)  # CHW -> HWC, the layout matplotlib expects
    ax=plt.subplot(3,3,i+1)
    ax.imshow(img.numpy())
    ax.set_title("label = %d"%label.item())
    ax.set_xticks([])
    ax.set_yticks([])
plt.show()



# PyTorch's default image tensor layout is Batch, Channel, Height, Width
dl_train = DataLoader(ds_train, batch_size=50, shuffle=True)
for x,y in dl_train:
    print(x.shape,y.shape) # torch.Size([50, 3, 32, 32]) torch.Size([50, 1])
    break
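Since transform_train and transform_valid are defined separately, the training pipeline is the natural place to add data augmentation. A minimal sketch (the specific augmentations are illustrative choices, not part of the original post, which uses plain ToTensor() for both):

from torchvision import transforms

# Common augmentations for 32x32 CIFAR-style images
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),     # flip left-right with p=0.5
    transforms.RandomCrop(32, padding=4),  # pad, then take a random 32x32 crop
    transforms.ToTensor(),
])
transform_valid = transforms.Compose([
    transforms.ToTensor(),                 # no augmentation at validation time
])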

The second method:
Reference blog: https://blog.csdn.net/l8947943/article/details/103733473
1. Subclass torch.utils.data.Dataset and override the __getitem__() and __len__() methods


import torch
import numpy as np


# Define a GetLoader class that subclasses Dataset and overrides __getitem__() and __len__()
class GetLoader(torch.utils.data.Dataset):
    # Constructor: store the data and the labels
    def __init__(self, data_root, data_label):
        self.data = data_root
        self.label = data_label
    # index is the per-sample index the DataLoader requests; return the sample together with its label
    def __getitem__(self, index):
        data = self.data[index]
        labels = self.label[index]
        return data, labels
    # Return the dataset length so the DataLoader knows how to split it into batches
    def __len__(self):
        return len(self.data)

# Randomly generate data: 10 rows, 20 columns
source_data = np.random.rand(10, 20)
# Randomly generate labels: 10 rows, 1 column
source_label = np.random.randint(0, 2, (10, 1))
# Wrap the data in GetLoader, which returns a Dataset object holding data and labels
torch_data = GetLoader(source_data, source_label)
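The same pattern generalizes to real data: __init__ collects lightweight metadata (file paths, labels) while __getitem__ does the actual disk I/O lazily, one sample per call. A minimal sketch of an image-folder version (the directory layout, class discovery, and PIL-based loading are illustrative assumptions, not from the original post):

import os
from PIL import Image
import torch
from torchvision import transforms

class ImageDataset(torch.utils.data.Dataset):
    # Assumed layout: root/<class_name>/<image files>, like ImageFolder expects
    def __init__(self, root, transform=transforms.ToTensor()):
        self.transform = transform
        self.samples = []
        # Sort class names so the label mapping is deterministic
        for label, cls in enumerate(sorted(os.listdir(root))):
            cls_dir = os.path.join(root, cls)
            for fname in sorted(os.listdir(cls_dir)):
                self.samples.append((os.path.join(cls_dir, fname), label))

    def __getitem__(self, index):
        path, label = self.samples[index]
        img = Image.open(path).convert("RGB")  # read one image lazily per call
        return self.transform(img), label

    def __len__(self):
        return len(self.samples)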

2. Load the data in parallel with DataLoader

from torch.utils.data import DataLoader

# Build the loader: batches of 6, shuffled, keep the last incomplete batch, 2 worker subprocesses
datas = DataLoader(torch_data, batch_size=6, shuffle=True, drop_last=False, num_workers=2)
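One practical caveat: with num_workers > 0 the DataLoader spawns worker subprocesses, so on platforms that use the spawn start method (e.g. Windows) the iteration must sit under the usual entry-point guard, roughly like this sketch:

if __name__ == '__main__':
    # Iterating inside the guard prevents the workers from re-executing the script body
    for i, data in enumerate(datas):
        print(i, data[0].shape, data[1].shape)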

3. Inspect the data by iterating over the loader

for i, data in enumerate(datas):
    # i is the batch index; data is a list holding that batch's samples and their labels
    print("Batch {}\n{}".format(i, data))

The output looks like this:

Batch 0
[tensor([[0.9599, 0.8651, 0.2255, 0.0347, 0.0917, 0.9121, 0.1441, 0.9894, 0.9187,
         0.8801, 0.0485, 0.8577, 0.8432, 0.0217, 0.2609, 0.7885, 0.4271, 0.6010,
         0.4486, 0.4694],
        [0.0324, 0.2408, 0.4294, 0.6394, 0.9968, 0.4153, 0.5748, 0.9075, 0.8704,
         0.2500, 0.5978, 0.0943, 0.9280, 0.8045, 0.5619, 0.4407, 0.0798, 0.0098,
         0.3712, 0.4186],
        [0.5342, 0.7337, 0.1067, 0.2624, 0.1423, 0.3960, 0.0439, 0.3460, 0.0646,
         0.8649, 0.3192, 0.4209, 0.8045, 0.5303, 0.5436, 0.8913, 0.5350, 0.4947,
         0.3241, 0.1768],
        [0.8492, 0.0950, 0.2038, 0.0865, 0.3746, 0.4050, 0.5040, 0.5224, 0.5192,
         0.7546, 0.3538, 0.1554, 0.9970, 0.2397, 0.6701, 0.1990, 0.6772, 0.5123,
         0.9840, 0.5672],
        [0.7546, 0.3447, 0.0682, 0.8481, 0.7333, 0.3628, 0.6533, 0.1724, 0.6848,
         0.5730, 0.6727, 0.4741, 0.9487, 0.4466, 0.8268, 0.5067, 0.5117, 0.5438,
         0.1003, 0.5986],
        [0.3786, 0.8163, 0.3150, 0.5195, 0.9077, 0.1611, 0.8182, 0.2060, 0.3715,
         0.5046, 0.5230, 0.8975, 0.7656, 0.9408, 0.8220, 0.8867, 0.0290, 0.8946,
         0.7680, 0.2677]], dtype=torch.float64), tensor([[1],
        [0],
        [0],
        [1],
        [0],
        [1]])]
Batch 1
[tensor([[0.4901, 0.5575, 0.2097, 0.1098, 0.5834, 0.0306, 0.1047, 0.4017, 0.7830,
         0.9238, 0.3405, 0.2155, 0.3767, 0.2743, 0.8154, 0.3525, 0.5874, 0.8691,
         0.0262, 0.2904],
        [0.9268, 0.8384, 0.9948, 0.2149, 0.1508, 0.2278, 0.6399, 0.3555, 0.5254,
         0.6366, 0.9150, 0.0842, 0.4703, 0.3684, 0.6052, 0.1764, 0.5499, 0.7318,
         0.4513, 0.3531],
        [0.5359, 0.9277, 0.2643, 0.3641, 0.3117, 0.7986, 0.7952, 0.6529, 0.4539,
         0.4004, 0.4223, 0.2886, 0.9924, 0.5950, 0.9733, 0.4068, 0.1523, 0.4911,
         0.7287, 0.4250],
        [0.0345, 0.3635, 0.9745, 0.2807, 0.1577, 0.4595, 0.6639, 0.1265, 0.7047,
         0.1411, 0.4033, 0.2724, 0.4256, 0.1492, 0.8040, 0.1352, 0.4836, 0.7783,
         0.7087, 0.0935]], dtype=torch.float64), tensor([[1],
        [1],
        [0],
        [1]])]
torch.utils.data.DataLoader(dataset, batch_size, shuffle, drop_last, num_workers)
'''
dataset: the torch.utils.data.Dataset object to load from
batch_size: number of samples per batch
shuffle: whether to shuffle the data each epoch
drop_last: whether to drop the last batch when the dataset size is not divisible by batch_size
num_workers: number of subprocesses used for data loading
'''
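Note the dtype=torch.float64 in the batches above: the default collate function preserves NumPy's float64, while PyTorch model parameters default to float32. Casting inside __getitem__ avoids a dtype mismatch later; a small sketch of the adjusted method (the original post returns the raw NumPy rows instead):

    def __getitem__(self, index):
        # Cast to float32 here so collated batches match the default
        # dtype of model parameters (torch.float32)
        data = torch.from_numpy(self.data[index]).float()
        labels = torch.from_numpy(self.label[index]).float()
        return data, labels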

Note: the reference blog link is included above.
