【PyTorch教程】P6-P7 数据加载

完整目录

关于 Dataset 数据加载

  • 如何理解数据的加载呢?
  • 就是说,要想获取自己电脑里的数据,读取它,那么就要遵守 pytorch 加载数据的规则。
    他的规则就是 定义一个class类,继承 Dataset (from torch.utils.data import Dataset),并且,在类中,定义三个函数,分别是:初始化 init、获得每一个数据 getitem、数据长度 len。
  • 这里面的过程,要很清楚:
    1、路径、合并路径、把文件夹中的每一个文件名称,做成一个列表(这是init要做的事情);
    2、访问init中的列表,把列表的名称逐一传递给一个变量,命名为name,再次合并路径,并且把文件名连接在路径之后,接下来,用PIL中的Image.open函数,读取(加载)上述路径的文件(命名为img)(这里肯定是图像了),返回 图像img和标签 label(这是getitem的工作);
    3、最后用len()返回列表的长度。
  • 定义好 以后,后面就可以实例化这个类,定义参数(本例其实是一个路径,一个夹名称了),名称可以和定义类中的不一样,但是位置要对应(奥,这可能是Python课程里说的位置参数?)。
  • 引用之前定义的类,把上述参数,传递进去。
  • 最后打印自定义数据列表的长度。

可运行的代码

import os

from PIL import Image
from torch.utils.data import Dataset


# dataset有两个作用:1、加载每一个数据,并获取其label;2、用len()查看数据集的长度
class MyData(Dataset):
    def __init__(self, root_dir, label_dir):  # 初始化,为这个函数用来设置在类中的全局变量
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir,self.label_dir)  # 单纯的连接起来而已,背下来怎么用就好了,因为在win下和linux下的斜线方向不一样,所以用这个函数来连接路径
        self.img_path = os.listdir(self.path)  # img_path 的返回值,就已经是一个列表了

    def __getitem__(self, idx):  # 获取数据对应的 label
        img_name = self.img_path[idx]  # img_name 在上一个函数的最后,返回就是一个列表了
        img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)  # 这行的返回,是一个图片的路径,加上图片的名称了,能够直接定位到某一张图片了
        img = Image.open(img_item_path)  # 这个步骤看来是不可缺少的,要想 show 或者 操作图片之前,必须要把图片打开(读取),也就是 Image.open()一下,这可能是 PIL 这个类型图片的特有操作
        label = self.label_dir  # 这个例子中,比较特殊,因为图片的 label 值,就是图片所在上一级的目录
        return img, label  # img 是每一张图片的名称,根据这个名称,就可以使用查看(直接img)、print、size等功能
        # label 是这个图片的标签,在当前这个类中,标签,就是只文件夹名称,因为我们就是这样定义的

    def __len__(self):
        return len(self.img_path)  # img_path,已经是一个列表了,len()就是在对这个列表进行一些操作


if __name__ == '__main__':
    root_dir = "F:\\PhD\\01-Python_In_One\\Project\\【B_up】XiaoTuDui\\data\\train"
    # root_dir = "data/train"
    ants_label_dir = "ants_image"
    bees_label_dir = "bees_image"
    ants_dataset = MyData(root_dir, ants_label_dir)
    bees_dataset = MyData(root_dir, bees_label_dir)
    train_dataset = ants_dataset + bees_dataset

完整代码 P6-7_read_data.py

# from torch.utils.tensorboard import SummaryWriter

# !usr/bin/env python3
# -*- coding:utf-8 -*-

"""
author :24nemo
 date  :2021年07月12日
"""

'''
Dataset:
能把数据进行编号
提供一种方式,获取数据,及其label,实现两个功能:
1、如何获取每一个数据,及其label
2、告诉我们总共有多少个数据

数据集的组织形式,有两种方式:
1、文件夹的名字,就是数据的label
2、文件名和label,分别处在两个文件夹中,label可以用txt的格式进行存储

在jupyter中,可以查看,help,两个方式:
1、help(Dataset)
2、Dataset??

Dataloader:
为网络提供不同的数据形式,比如将0、1、2、3进行打包

这一节内容很重要
'''
'''
# writer = SummaryWriter("logs")

class MyData(Dataset):
    def __init__(self, root_dir, image_dir, label_dir, transform):
        #  初始化,为这个函数用来设置在类中的全局变量
        self.root_dir = root_dir
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.label_path = os.path.join(self.root_dir, self.label_dir)
        self.image_path = os.path.join(self.root_dir, self.image_dir)
        self.image_list = os.listdir(self.image_path)
        self.label_list = os.listdir(self.label_path)
        self.transform = transform
        # 因为 label 和 Image文件名相同,进行一样的排序,可以保证取出的数据和label是一一对应的
        self.image_list.sort()
        self.label_list.sort()

    def __getitem__(self, idx):
        img_name = self.image_list[idx]
        label_name = self.label_list[idx]
        img_item_path = os.path.join(self.root_dir, self.image_dir, img_name)
        label_item_path = os.path.join(self.root_dir, self.label_dir, label_name)
        img = Image.open(img_item_path)

        with open(label_item_path, 'r') as f:
            label = f.readline()

        # img = np.array(img)
        img = self.transform(img)
        sample = {'img': img, 'label': label}
        return sample

    def __len__(self):
        # assert len(self.image_list) == len(self.label_list)
        return len(self.image_list)


if __name__ == '__main__':
    transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])
    root_dir = "dataset/train"
    image_ants = "ants_image"
    label_ants = "ants_label"
    ants_dataset = MyData(root_dir, image_ants, label_ants, transform)
    image_bees = "bees_image"
    label_bees = "bees_label"
    bees_dataset = MyData(root_dir, image_bees, label_bees, transform)
    train_dataset = ants_dataset + bees_dataset

    # transforms = transforms.Compose([transforms.Resize(256, 256)])
    dataloader = DataLoader(train_dataset, batch_size=1, num_workers=2)

    # writer.add_image('error', train_dataset[119]['img'])
    # writer.close()
    # for i, j in enumerate(dataloader):
    #     # imgs, labels = j
    #     print(type(j))
    #     print(i, j['img'].shape)
    #     # writer.add_image("train_data_b2", make_grid(j['img']), i)
    
    # writer.close()

#  jupyter notebook 等方法,可以查看 help
'''

'''
以下内容是视频中完全一样的代码,截图,在 20210713 的笔记中,包括 python console 的代码也有保存
'''

运行结果

在这里插入图片描述

完整目录

  • 61
    点赞
  • 112
    收藏
    觉得还不错? 一键收藏
  • 7
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值