pytorch中 torch.utils.data的用法 ----加载数据篇

本文介绍了如何使用PyTorch的Dataset类处理两种不同形式的数据集,一种是根据文件夹名字作为标签,另一种是图片与标签信息分开存放。通过重写__getitem__和__len__函数,实现了数据集的读取和标签获取。并提供了具体代码示例进行解释。
摘要由CSDN通过智能技术生成

1. 我们将torch.utils.data 别名为Dataset ,即:

from torch.utils.data import Dataset

2.通过help()语句查看Dataset的用法,其实还有一种简便方式查看排版更清晰的使用说明:

通过PyCharm中的terminal输入jupyter notebook 语句,打开网页版的jupyter notebook

在要查询的函数后打两个“?” ,运行即可得到排版良好的帮助文档:

对Dataset的帮助文档简要概括就是:

Dataset类是一个抽象类,所有的数据集都要继承这个Dataset类,并且所有的子类都需要重写__getitem__函数
__getitem__函数需要做的事情是:获取每个数据及其应的label
同时,我们还可以选择性的重写__len__函数,它的default形式是返回数据集的size

原文如下:

 3.实现   ----  一个图片二分类问题(给定一张图片,分类其到底是蜜蜂还是蚂蚁):

(一)第一种常见的数据组织形式及其处理:

1.数据集组织形式:

数据集名字是P567DataSet,P567DataSet数据集包括两个文件夹:训练集的train文件夹 和 测试集的val文件夹

训练集train文件夹包括两个文件,一个叫ants,一个叫 bees 。测试集val文件及同理

由于该图片二分类的输入就是一张张图片,输出的就是个分类标签 bee 或者 ant。由于输出的标签很简单,因此,P567DataSet数据集中对各个图片的标签就写在了文件夹名字上,即:ants文件夹中就是蚂蚁的图片,bees文件夹中就是蜜蜂的图片。ants和bees文件夹中的图片名字都是随便起的,没有啥特别的含义。文件位置和内容大致如下图所示:

 2.处理方式:

"""
torch.utils.data类的学习,一般将该类重命名为 Dataset
参考链接:  (堆老师讲得好好,大家积极点赞投币收藏哇!)
https://www.bilibili.com/video/BV1hE411t7RN?p=6&vd_source=6ef95a7fa007076dc9195eba914171b5
https://www.bilibili.com/video/BV1hE411t7RN?p=7&spm_id_from=pageDriver&vd_source=6ef95a7fa007076dc9195eba914171b5
"""

#P567DataSet的数据方式读取


import torch
from torch.utils.data import Dataset
from PIL import Image  #一个常见的用于展示图片的库
import os   #os是python当中关于系统的一个库。在本段代码中,用于获取图片的地址和地址之间的拼接工作


# Dataset类是一个抽象类,所有的数据集使用都要继承这个Dataset类,并且所有的子类都需要重写__getitem__函数
# __getitem__函数需要做的事情是:当给定index后,就可以得到数据集中对应index的 训练内容输入x 和 其对应的label值y
# 同时,我们还可以选择性的重写__len__函数,它的default形式是返回数据集的size


class MyDataSet(Dataset):
    def __init__(self , t_root_dir , t_type_name):
        #之所以用这种组织形式是因为本例中数据的整理形式是: 训练数据ants类的图片都放在名为ants的文件夹中了,ants文件夹的名字就是该训练图片的label
        self.root_dir = t_root_dir
        self.type_dir = t_type_name

        # 由于要实现的__getitem__函数要求:给出编号 idx 时,从数据列表中返回该 idx 对应的 数据内容 和 其对应的label
        # 所以,首先要实现一个列表,包括 这个训练数据集的 全部数据的 内容(input xi)和 其标签(output yi) .
        # 其组织形式是[ [input x1 , output y1] , [input x2 , output y2] , [input x3 , output y3] , [input x4 , output y4] , ...  ]
        # 那么对于本例来说,训练数据内容就是一张张图片,那么 input xi 就是对应的图片的地址 ,ouput yi 就是对应图片的label
        self.whole_dir = os.path.join(self.root_dir, self.type_dir)  # 得到想要的 图片数据集的 地址
                                                                     # 函数os.join(AA,BB)的意思是:将AA和BB两个路径拼接在一起,那为啥不用字符串拼接呢? --这是因为这涉及代码的移植问题,在win中,路径是用\\来间隔,但是在ubuntu中是用/来间隔的
                                                                     # PS:函数os.join()中的参数可以是多个,eg,os.join(AA,BB,CC,DD) 意思就是把AA,BB,CC,DD四个地址拼在一起

        self.img_path_list = sorted(os.listdir(self.whole_dir))  # 函数os.listdir(XX)的含义是:将参数XX位置处的 所有文件的 地址 以列表中元素的形式 保存到一个列表中 并返回。  ==> 此时得到的仅是whole_dir文件夹下包含的文件的 名称列表,其实准确来说也不太能称得上是地址列表,只是个相对路径
                                                                 # 由于,os.listdir()得到的顺序是随机的,它的输出结果并不是按照某种特定顺序来的,这样会导致一些问题,比如:若图片和label之间不是现在这种对应关系,而是另有一个文件将 图片与label相对应(如本博客中的例(2)的数据组织形式),此时的顺序随机就会导致图片名与标签名不对应


    def __getitem__(self,idx):  #该函数要实现的内容是: 当给出编号 idx 时,从数据列表中返回该idx对应的 数据内容 和 其对应的label
        #根据已经在构造函数值中建好的self.img_path_list , 对应出 idx 相应的 数据内容 和 其label
        img_idx_path = os.path.join( self.whole_dir , self.img_path_list[idx] )   #拼接得到idx对应的图片的全部地址,self.img_path_list[idx]只是图片的名字。  拼接后长这样: 'P567DataSet/train/ants/0013035.jpg'
        img_idx_label = self.type_dir

        return img_idx_path , img_idx_label


#验证部分
train_ants_data = MyDataSet("P567DataSet/train","ants")
train_bees_data = MyDataSet("P567DataSet/train","bees")

print(train_ants_data[0])

imgpathbee1 , labelbee1 = train_bees_data[0]
imgpathant1 , labelant1 = train_ants_data[0]

imgant1 = Image.open(imgpathant1)  #利用PIL中的Image库,输入图片路径,返回一个Image类型的实例
imgbee1 = Image.open(imgpathbee1)

imgbee1.show()  # 将Image实例imgbee1调用.show()函数展示该图片
imgant1.show()

为了方便理解,我贴了两张运行时,对于变量img_path_list都包含了什么的截图:

 

 (二)第二种常见的数据组织形式及其处理:

1.数据集组织形式:

数据集名字是SecondTypeDataset,SecondTypeDataset 数据集包括两个文件夹:训练集的train文件夹 和 测试集的val文件夹

训练集train文件夹包括四个文件夹,内容是蚂蚁图片的ants_image文件夹;内容是蚂蚁图片对应label 标签的txt文件;内容是蜜蜂图片的bees_image文件夹;内容是蜜蜂图片对应 label 标签的 txt文件。测试集val文件同理。

这种数据集的组织形式时更常见的,(以图片分类作为实际例子来举例讲述),因为往往一张图片所包含的信息不仅仅是拥有一个简单的类别标签。

比如说对于更高级的图像分析任务:如:确定图像中是否有汽车,并且汽车处于图像的那个位置。那么对于一张图片其标签至少有两个,第一个是tag标签,表明图片中是否有汽车,第二个是位置标签,表明图片中的汽车是在那个位置上。

那么,针对于这种情况,我们就无法再将这一对的标签信息体现在在文件夹地方命名或者图片自己的命名中了。

一般选择的方式是,给定图片,其命名为aa.jpg ,然后有一个同名的txt文件:aa.txt 。在aa.txt文件中,写入aa.jpg图片的一些label信息。然后将这些图片文件整理到一个文件夹中,这些带有label信息的txt文件整理到另一个文件夹中。需要匹配的时候,由于图片文件和label文件的文件名相同,所以就很容易匹配上,方便数据的进一步整理。

对应到本任务:虽然是简单的二分类图像识别,但是,数据集也是将图片和label信息分开放置。以训练集 train 文件夹 中的 ants_image文件夹和其对应的ants_label文件夹为例:ants_image文件夹中包含多张蚂蚁的图片,这些图片的命名是随便起的,没啥含义。但是,在ants_label文件夹中,包含许多txt文件,这些 txt 文件的名称与 ants_image 文件夹中的图片的名称是一一对应的:ants_image 文件夹中 同名的图片 其标签在 ants_label中

数据整理方式如下图所示:

 2.处理方式:

"""
torch.utils.data类的学习,一般将该类重命名为 Dataset
参考链接:  (堆老师讲得好好,大家积极点赞投币收藏哇!)
https://www.bilibili.com/video/BV1hE411t7RN?p=6&vd_source=6ef95a7fa007076dc9195eba914171b5
https://www.bilibili.com/video/BV1hE411t7RN?p=7&spm_id_from=pageDriver&vd_source=6ef95a7fa007076dc9195eba914171b5
"""

#SecondTypeDataset的数据方式读取

import torch
from torch.utils.data import Dataset
from PIL import Image  #一个常见的用于展示图片的库
import os   #os是python当中关于系统的一个库。在本段代码中,用于获取图片的地址和地址之间的拼接工作


# Dataset类是一个抽象类,所有的数据集使用都要继承这个Dataset类,并且所有的子类都需要重写__getitem__函数
# __getitem__函数需要做的事情是:当给定index后,就可以得到数据集中对应index的 训练内容输入x 和 其对应的label值y
# 同时,我们还可以选择性的重写__len__函数,它的default形式是返回数据集的size


class MyDataSet(Dataset):
    def __init__(self , t_root_dir , t_data_type):  #t_root_dir = "SecondTypeDataset/train"    t_type_name只能有俩选项:ants , bees
        self.root_dir = t_root_dir
        #此时的t_data_type不再代表着位置路径,现在它只是代表数据的种类了
        self.data_type = t_data_type
        if self.data_type == "ants" :
            self.img_dir = "ants_image"
            self.label_dir = "ants_label"
        elif self.data_type == "bees" :
            self.img_dir = "bees_image"
            self.label_dir = "bees_label"
        else :
            print("请在第二个参数:t_data_type处选择 ants 或 bees 输入, 否则无法建立正确的MyDataset实例 ")
            exit()

        self.whole_img_dir = os.path.join(self.root_dir,self.img_dir)       # "SecondTypeDataset/train/ants_image"
        self.img_path_list = sorted( os.listdir(self.whole_img_dir) )
        self.whole_label_dir = os.path.join(self.root_dir,self.label_dir)   # "SecondTypeDataset/train/ants_label"
        self.label_path_list = sorted( os.listdir(self.whole_label_dir) )


    def __getitem__(self,idx):  #该函数要实现的内容是: 当给出编号 idx 时,从数据列表中返回该idx对应的 数据内容 和 其对应的label
        #根据已经在构造函数值中建好的self.img_path_list , self.label_path_list 对应出 idx 相应的 数据内容 和 其label
        img_idx_path = os.path.join( self.whole_img_dir , self.img_path_list[idx] )   #拼接得到idx对应的图片的全部地址,self.img_path_list[idx]只是图片的名字。  拼接后长这样: 'P567DataSet/train/ants/0013035.jpg'
        #label得从对应文件的地址中得到该txt文件,并读取文件内容
        label_idx_path = os.path.join(self.whole_label_dir,self.label_path_list[idx])       #idx对应的label文件的地址
        #根据文件地址,读取该文件中对image的label值:
        with open(label_idx_path, "r") as f:
            label_idx_content = f.readline()
            print("label为:", label_idx_content)

        return img_idx_path , label_idx_content


#验证部分
train_ants_data = MyDataSet("SecondTypeDataset/train","ants")
train_bees_data = MyDataSet("SecondTypeDataset/train","bees")

print(train_ants_data[0])

imgpathbee1 , labelbee1 = train_bees_data[0]
imgpathant1 , labelant1 = train_ants_data[0]

imgant1 = Image.open(imgpathant1)  #利用PIL中的Image库,输入图片路径,返回一个Image类型的实例
imgbee1 = Image.open(imgpathbee1)

imgbee1.show()  # 将Image实例imgbee1调用.show()函数展示该图片
imgant1.show()

ref:

1. Python中os.listdir的乱序问题的解决 :

.Python中os.listdir的乱序问题_陨星落云的博客-CSDN博客2

2.python对文件的读写操作 :

python读取、写入txt文本内容_洞幺01的博客-CSDN博客_python txt

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值