PyTorch 中自定义数据集的读取方法

显然我们在学习深度学习时,不能只局限于通过使用官方提供的MNSIT、CIFAR-10、CIFAR-100这样的数据集,很多时候我们还是需要根据自己遇到的实际问题自己去搜集数据,然后制作数据集(收集数据集的方法有很多,这里就不过多的展开了)。这里只介绍数据集的读取。

  1. 自定义数据集的方法
    首先创建一个Dataset类
    在这里插入图片描述
    在代码中:
    def init() 一些初始化的过程写在这个函数下
    def len() 返回所有数据的数量,比如我们这里将数据划分好之后,这里仅仅返回的是被处理后的关系
    def getitem() 回数据和标签

  2. 补充代码
    上述已经将框架打出来了,接下来就是将框架填充完整就行了,下面是完整的代码,代码的解释说明我也已经写在其中了

# -*- coding: utf-8 -*-
# @Author  : 胡子旋
# @Email   :1017190168@qq.com

import torch
import os,glob
import visdom
import time
import torchvision
import random,csv
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms
from PIL import Image

class pokemom(Dataset):
    def __init__(self,root,resize,mode,):
        super(pokemom,self).__init__()
        # 保存参数
        self.root=root
        self.resize=resize
        # 给每一个类做映射
        self.name2label={}  # "squirtle":0 ,"pikachu":1……
        for name in sorted(os.listdir(os.path.join(root))):
            # 过滤掉文件夹
            if not os.path.isdir(os.path.join(root,name)):
                continue
            # 保存在表中;将最长的映射作为最新的元素的label的值
            self.name2label[name]=len(self.name2label.keys())
        print(self.name2label)
        # 加载文件
        self.images,self.labels=self.load_csv('images.csv')
        # 裁剪数据
        if mode=='train':
            self.images=self.images[:int(0.6*len(self.images))]   # 将数据集的60%设置为训练数据集合
            self.labels=self.labels[:int(0.6*len(self.labels))]   # label的60%分配给训练数据集合
        elif mode=='val':
            self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))]  # 从60%-80%的地方
            self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))]
        else:
            self.images = self.images[int(0.8 * len(self.images)):]   # 从80%的地方到最末尾
            self.labels = self.labels[int(0.8 * len(self.labels)):]
        # image+label 的路径
    def load_csv(self,filename):
        # 将所有的图片加载进来
        # 如果不存在的话才进行创建
        if not os.path.exists(os.path.join(self.root,filename)):
            images=[]
            for name in self.name2label.keys():
                images+=glob.glob(os.path.join(self.root,name,'*.png'))
                images+=glob.glob(os.path.join(self.root, name, '*.jpg'))
                images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
            print(len(images),images)
            # 1167 'pokeman\\bulbasaur\\00000000.png'
            # 将文件以上述的格式保存在csv文件内
            random.shuffle(images)
            with open(os.path.join(self.root,filename),mode='w',newline='') as f:
                writer=csv.writer(f)
                for img in images:    #  'pokeman\\bulbasaur\\00000000.png'
                    name=img.split(os.sep)[-2]
                    label=self.name2label[name]
                    writer.writerow([img,label])
                print("write into csv into :",filename)

        # 如果存在的话就直接的跳到这个地方
        images,labels=[],[]
        with open(os.path.join(self.root, filename)) as f:
            reader=csv.reader(f)
            for row in reader:
                # 接下来就会得到 'pokeman\\bulbasaur\\00000000.png' 0 的对象
                img,label=row
                # 将label转码为int类型
                label=int(label)
                images.append(img)
                labels.append(label)
        # 保证images和labels的长度是一致的
        assert len(images)==len(labels)
        return images,labels


    # 返回数据的数量
    def __len__(self):
        return len(self.images)   # 返回的是被裁剪之后的关系

    def denormalize(self, x_hat):

        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
        std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
        # print(mean.shape, std.shape)
        x = x_hat * std + mean
        return x
    # 返回idx的数据和当前图片的label
    def __getitem__(self,idx):
        # idex-[0-总长度]
        # retrun images,labels
        # 将图片,label的路径取出来
        # 得到的img是这样的一个类型:'pokeman\\bulbasaur\\00000000.png'
        # 然而label得到的则是 0,1,2 这样的整形的格式
        img,label=self.images[idx],self.labels[idx]
        tf=transforms.Compose([
            lambda x:Image.open(x).convert('RGB'),  # 将t图片的路径转换可以处理图片数据
            # 进行数据加强
            transforms.Resize((int(self.resize*1.25),int(self.resize*1.25))),
            # 随机旋转
            transforms.RandomRotation(15),   # 设置旋转的度数小一些,否则的话会增加网络的学习难度
            # 中心裁剪
            transforms.CenterCrop(self.resize),   # 此时:既旋转了又不至于导致图片变得比较的复杂
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485,0.456,0.406],
                                 std=[0.229,0.224,0.225])

        ])
        img=tf(img)
        label=torch.tensor(label)
        return img,label




def main():
    # 验证工作
    viz=visdom.Visdom()

    db=pokemom('pokeman',64,'train')  # 这里可以改变大小 224->64,可以通过visdom进行查看
    # 可视化样本
    x,y=next(iter(db))
    print('sample:',x.shape,y.shape,y)
    viz.image(db.denormalize(x),win='sample_x',opts=dict(title='sample_x'))
    # 加载batch_size的数据
    loader=DataLoader(db,batch_size=32,shuffle=True,num_workers=8)
    for x,y in loader:
        viz.images(db.denormalize(x),nrow=8,win='batch',opts=dict(title='batch'))
        viz.text(str(y.numpy()),win='label',opts=dict(title='batch-y'))
        # 每一次加载后,休息10s
        time.sleep(10)

if __name__ == '__main__':
    main()
  • 19
    点赞
  • 147
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
### 回答1: 在 PyTorch 读取自定义数据集的一般步骤如下: 1. 定义数据集类:首先需要定义一个数据集类,继承自 `torch.utils.data.Dataset` 类,并实现 `__getitem__` 和 `__len__` 方法。在 `__getitem__` 方法,根据索引返回一个样本的数据和标签。 2. 加载数据集:使用 `torch.utils.data.DataLoader` 类加载数据集,可以设置批量大小、多线程读取数据等参数。 下面是一个简单的示例代码,演示如何使用 PyTorch 读取自定义数据集: ```python import torch from torch.utils.data import Dataset, DataLoader class CustomDataset(Dataset): def __init__(self, data, targets): self.data = data self.targets = targets def __getitem__(self, index): x = self.data[index] y = self.targets[index] return x, y def __len__(self): return len(self.data) # 加载训练集和测试集 train_data = ... train_targets = ... train_dataset = CustomDataset(train_data, train_targets) train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) test_data = ... test_targets = ... test_dataset = CustomDataset(test_data, test_targets) test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False) # 训练模型 for epoch in range(num_epochs): for batch_idx, (data, targets) in enumerate(train_loader): # 前向传播、反向传播,更新参数 ... ``` 在上面的示例代码,我们定义了一个 `CustomDataset` 类,加载了训练集和测试集,并使用 `DataLoader` 类分别对它们进行批量读取。在训练模型时,我们可以像使用 PyTorch 自带的数据集一样,循环遍历每个批次的数据和标签,进行前向传播、反向传播等操作。 ### 回答2: PyTorch是一个开源的深度学习框架,它提供了丰富的功能用于读取和处理自定义数据集。下面是一个简单的步骤来读取自定义数据集。 首先,我们需要定义一个自定义数据集类,该类应继承自`torch.utils.data.Dataset`类,并实现`__len__`和`__getitem__`方法。`__len__`方法应返回数据集的样本数量,`__getitem__`方法根据给定索引返回一个样本。 ```python import torch from torch.utils.data import Dataset class CustomDataset(Dataset): def __init__(self, data): self.data = data def __len__(self): return len(self.data) def __getitem__(self, idx): sample = self.data[idx] return torch.tensor(sample) ``` 接下来,我们可以创建一个数据集实例并传入自定义数据。假设我们有一个包含多个样本的列表 `data`。 ```python data = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] dataset = CustomDataset(data) ``` 然后,我们可以使用`torch.utils.data.DataLoader`类加载数据集,并指定批次大小、是否打乱数据等。 ```python batch_size = 2 dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True) ``` 现在,我们可以迭代数据加载器来获取批次的样本。 ```python for batch in dataloader: print(batch) ``` 上面的代码将打印出两个批次的样本。如果`shuffle`参数设置为`True`,则每个批次的样本将是随机的。 总而言之,PyTorch提供了简单而强大的工具来读取和处理自定义数据集,可以根据实际情况进行适当修改和扩展。 ### 回答3: PyTorch是一个流行的深度学习框架,可以用来训练神经网络模型。要使用PyTorch读取自定义数据集,可以按照以下几个步骤进行: 1. 准备数据集:将自定义数据集组织成合适的目录结构。通常情况下,可以将数据集分为训练集、验证集和测试集,每个集合分别放在不同的文件夹。确保每个文件夹的数据按照类别进行分类,以便后续的标签处理。 2. 创建数据加载器:在PyTorch,数据加载器是一个有助于有效读取和处理数据的类。可以使用`torchvision.datasets.ImageFolder`类创建一个数据加载器对象,通过传入数据集的目录路径来实现。 3. 数据预处理:在将数据传入模型之前,可能需要对数据进行一些预处理操作,例如图像变换、标准化或归一化等。可以使用`torchvision.transforms`的类来实现这些预处理操作,然后将它们传入数据加载器。 4. 创建数据迭代器:数据迭代器是连接数据集和模型的重要接口,它提供了一个逐批次加载数据的功能。可以使用`torch.utils.data.DataLoader`类创建数据迭代器对象,并设置一些参数,例如批量大小、是否打乱数据等。 5. 使用数据迭代器:在训练时,可以使用Python的迭代器来遍历数据集并加载数据。通常,它会在每个迭代步骤返回一个批次的数据和标签。可以通过`for`循环来遍历数据迭代器,并在每个步骤处理批次数据和标签。 这样,我们就可以在PyTorch成功读取并处理自定义数据集。通过这种方式,我们可以更好地利用PyTorch的功能来训练和评估自己的深度学习模型。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

陶陶name

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值