Pytorch学习之数据加载


一、Dataset类

这个类可以看成是自定义的数据集类(是一个抽象类,不能直接实例化,只能继承)
代码如下(示例):

class Mydataset(Dataset):
	def __init__(self,):
		pass
	def __len__(self):
		pass
	def __getitem__(self,idx)
		pass

一、当数据集比较小时,可以把整个数据集放入init中(即放入内存中),再根据getitem的索引来读出
二、当数据集比较大时(如图像数据集),一般要先做一个列表,来记录下每张图像的id。在getitem函数里读取列表中第i个图像id,系统会从文件夹中将图片读出,返回

二、torchvision.transforms.Compose使用

这个类的主要作用是串联多个图片变换的操作。Compose里面的参数实际上就是个列表

通常预处理步骤:

  1. 所有图片转化为相同大小。
  2. 把图片数据集转换为Pytorch张量
  3. 用数据集的均值和标准差把数据集归一化

代码如下(示例):

transforms = transforms.Compose([transforms.Resize((224, 224)),  # cannot 224, must (224, 224)
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

三、torchvision.datasets.ImageFolder使用详解

ImageFolder是一个通用的数据加载器,数据如放在文件夹中的图片
使用详情

dataset=torchvision.datasets.ImageFolder(
                       root, 
                       transform=None, 
                       target_transform=None, 
                       loader=<function default_loader>, 
                       is_valid_file=None)

1.参数详解
root:图片存储的根目录,即各类别文件夹所在目录的上一级目录。
transform:对图片进行预处理的操作(函数),原始图片作为输入,返回一个转换后的图片。
target_transform:对图片类别进行预处理的操作,输入为 target,输出对其的转换。如果不传该参数,即对 target 不做任何转换,返回的顺序索引 0,1, 2…即label是按照文件夹命名从0开始的数字
loader:表示数据集加载方式,通常默认加载方式即可。
is_valid_file:获取图像文件的路径并检查该文件是否为有效文件的函数(用于检查损坏文件)

2.返回的dataset都有以下三种属性:

self.classes:用一个 list 保存类别名称
self.class_to_idx:类别对应的索引,与不做任何转换返回的 target 对应
self.imgs:保存(img-path, class) tuple的 list

代码如下(示例):

train_dataset = datasets.ImageFolder(root=./data/train, 
                                     transform=transforms)

我们得到的train_dataset,它的结构就是[(img_data,class_id),(img_data,class_id),…]

print(train_dataset[995])  # 第995个图片  class_id=1
'''
输出:
(tensor([[[-0.1765, -0.1686, -0.1686,  ..., -0.2941, -0.2941, -0.3020],
         [-0.1765, -0.1765, -0.1608,  ..., -0.2941, -0.2941, -0.2863],
         [-0.1765, -0.1765, -0.1608,  ..., -0.2863, -0.2863, -0.2784],
         ...,
         [-0.2078, -0.1922, -0.1843,  ..., -0.1608, -0.1608, -0.1608],
         [-0.1608, -0.1922, -0.1843,  ..., -0.1608, -0.1608, -0.1608],
         [-0.1922, -0.1686, -0.2000,  ..., -0.1686, -0.1608, -0.1529]],

        [[-0.2392, -0.2314, -0.2314,  ..., -0.3176, -0.3176, -0.3176],
         [-0.2392, -0.2392, -0.2235,  ..., -0.3176, -0.3098, -0.3020],
         [-0.2392, -0.2392, -0.2235,  ..., -0.3176, -0.3176, -0.3098],
         ...,
         [-0.3490, -0.3569, -0.3333,  ..., -0.3020, -0.3020, -0.3020],
         [-0.3098, -0.3412, -0.3333,  ..., -0.3020, -0.3020, -0.3020],
         [-0.3490, -0.3098, -0.3490,  ..., -0.3098, -0.3020, -0.2941]],

        [[-0.7255, -0.7176, -0.7176,  ..., -0.8745, -0.8824, -0.8824],
         [-0.7255, -0.7255, -0.7098,  ..., -0.8745, -0.8902, -0.8824],
         [-0.7255, -0.7255, -0.7098,  ..., -0.8588, -0.8745, -0.8667],
         ...,
         [-0.8353, -0.8745, -0.7882,  ..., -0.6784, -0.6784, -0.6784],
         [-0.7882, -0.8588, -0.7882,  ..., -0.6784, -0.6784, -0.6784],
         [-0.8039, -0.7882, -0.8039,  ..., -0.6941, -0.6784, -0.6706]]]), 1)
'''

再看三个属性

print(dataset.classes)  #根据分的文件夹的名字来确定的类别
print(dataset.class_to_idx) #按顺序为这些类别定义索引为0,1...
print(dataset.imgs) #返回从所有文件夹中得到的图片的路径以及其类别
'''
输出:
['cat', 'dog']
{'cat': 0, 'dog': 1}
[('./data/train\\cat\\1.jpg', 0), 
 ('./data/train\\cat\\2.jpg', 0), 
 ('./data/train\\dog\\1.jpg', 1), 
 ('./data/train\\dog\\2.jpg', 1)]
'''

四、按批加载数据-----DataLoader类

数据集过大不能一次性全部加载到内存里,可按批次来加载数据
使用详情

train_loader = torch.utils.data.DataLoader(train_dataset,	# 导入的训练集
                                           batch_size=4, 	# 每批训练的样本数
                                           shuffle=True,	# 是否打乱训练集
                                           num_workers=0)	# 使用线程数,在windows下设置为0。多线程能提高读取效率
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
PyTorch是一个基于Python的科学计算包,其主要功能是进行张量计算和深度学习模型构建。在深度学习中,数据加载是一个重要的环节,PyTorch提供了一些工具和函数来简化数据加载的过程。 PyTorch数据加载主要涉及到两个类:`torch.utils.data.Dataset`和`torch.utils.data.DataLoader`。其中,`Dataset`类用于表示数据集,而`DataLoader`类则用于对数据集进行加载和处理。 使用PyTorch进行数据加载的基本步骤如下: 1. 定义数据集:需要继承`torch.utils.data.Dataset`类,并实现`__len__`和`__getitem__`方法。其中,`__len__`方法返回数据集的大小,`__getitem__`方法用于获取指定索引的数据。 2. 创建数据集实例:将定义好的数据集实例化,并传入相应的参数(如文件路径等)。 3. 创建数据加载器:使用`torch.utils.data.DataLoader`类创建数据加载器,可以指定批次大小、是否打乱数据、多进程等参数。 4. 迭代数据:使用for循环迭代数据加载器,每次迭代返回一个批次的数据。 下面是一个简单的示例代码,用于加载MNIST数据集: ```python import torch from torch.utils.data import Dataset, DataLoader from torchvision import datasets, transforms # 定义自己的数据集类 class MyDataset(Dataset): def __init__(self, path): self.data = torch.load(path) self.transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ]) def __len__(self): return len(self.data) def __getitem__(self, index): x, y = self.data[index] x = self.transform(x) return x, y # 创建数据集实例 train_dataset = MyDataset('mnist/train.pt') test_dataset = MyDataset('mnist/test.pt') # 创建数据加载器 train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True) test_loader = DataLoader(test_dataset, batch_size=64, shuffle=True) # 迭代数据 for batch_idx, (data, target) in enumerate(train_loader): # 对批次数据进行训练或测试 ... ```

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

想要躺平的一枚

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值