PyTorch 入门:自定义数据加载

之前学习tensorflow时也学习了它的数据加载,不过在网上看了很多教程后还是有很多小问题,不知道为什么在别人电脑上可以运行但是我的就不行(把我头搞晕了),很烦,这时想起之前听导师说PyTorch容易入门上手,所以果断去学了PyTorch,写这篇博文的目的就是总结学到的,然后记录下来,也希望以后学到新的知识或技术能够用写博客的方式记录下来,这样有助于形成比较好的知识体系,也方便以后温故知新。

在进行深度学习实验前,必须准备数据,而在准备数据的时候有一个步骤就是把数据封装成符合深度学习模型要求的数据形式,方便模型读取。

在PyTorch中,有一个torch.utils.data.Dataset类,这是一个抽象类,其他所有不管是公开的官方数据集还是自定义数据集都必须继承这个抽象类(比如MNIST数据集),继承这个抽象类的同时必须重写它的两个函数:__len__()  和    __getitem__()。

__len__():返回数据集的大小,比如我的数据集有500张图片,那么就返回500

__getitem__():返回一张图片

具体怎么定义自定义数据集,代码如下:

import torch
from torch.utils.data import Dataset #首先导入这个抽象类
from skimage import io,transform


class MyDataset(Dataset):
    """
    这是一个初始化函数,相当于c++的构造函数,定义类的传入参数和初始化
    root_dir是数据集的路径,transform是一个数据处理操作
    """
    def __init__(self,root_dir,transform=None):  

        #os.listdir函数读取路径下所以文件的文件名,并组成一个列表并返回

        self.file = os.listdir(root_dir)  
     
        self.root_dir = root_dir

        self.transform = transform

       
        
    def __len__(self):
        return len(self.file)   #返回这给列表的大小
    
    
    def __getitem__(self,index):
    #将传入路径和文件名组成一个新的地址,这个数据就是单个数据的具体地址,方便之后以地址读取该数据
        img_name = os.path.join(self.root_dir, self.file[index])
    #从文件名中获取标签(我的数据标签在文件名中,比如一张狗的图片名:25_dog.jpg,25是序号,dog是
    #标签)
        if img_name[-7:-4] == 'dog':
            label = 0
        else: label = 1
        #根据上面获得的具体地址,读取这张图片
        image = io.imread(img_name)

        #对图片进行缩放为一个大小,方便深度学习模型处理
        image = transform.resize(image,(128,128))

        #对图片的维度进行转换,
        #numpy的三个维度顺序为:H * W * C
        #而torch的张量维度顺序:C * H * W ,所以模型要处理它必须转换成torch的形式
        image = image.transpose((2, 0, 1))

        #返回数据和标签
        return image,label

自定义数据集定义好了,那么怎么批量加载它呢,PyTorch使用多线程加载数据,模型需要使用时才加载进内存让模型读取,而使用批量读取数据必须使用PyTorch的torch.utils.data.DataLoader类,使用方法如下:

mydataloader = torch.utils.data.DataLoader()

#首先实例化一个自定义数据类
dataset = MyDataset(root_dir='./dog_vs_cat/train/',transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ]))

#然后实例化一个数据加载类
dataloader = torch.utils.data.DataLoader(dataset,batch_size=100,shuffle=True,num_workers=0)

第一个参数是要加载的数据类

第二个参数是数据加载时每个批次多少数据

第三个参数设置数据加载时是否打乱数据

第四个参数设置多线程的个数,默认值是0,表示单个线程

 

然后就可以在迭代器中使用了

for batch_idx, (image, label) in enumerate(dataloader): 
         

batch_idx 表示迭代器返回的自带序号

(image,label)表示返回的数据和标签

最后就可以将返回的数据和标签输入到模型中训练了,完美!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值