【Pytorch学习 3】PyTorch加载数据初认识

参考视频:

P5. PyTorch加载数据初认识_哔哩哔哩_bilibili

P6. Dataset类代码实战_哔哩哔哩_bilibili

1. dataset

1.1. 作用:

  • 获取数据&其label
  • 告知共有多少数据

1.2. 实战

1.2.1. 数据集准备

下载数据集:https://download.pytorch.org/tutorial/hymenoptera_data.zip

该数据集是蚂蚁和蜜蜂的二分类数据集,label就是文件名字,数据集结构如下

hymenoptera_data
  |__train
  |    |__ants
  |    |__bees
  |      
  |__val
       |__ants
       |__bees  

常见的数据集组织形式:

  • 文件名是label
  • label写在.txt里
  • 图片名就是label

1.2.2. python文件

from torch.utils.data import Dataset # 导入包
from PIL import Image # 用来读取图片
import os # 用来获取所有数据的地址

class MyData(Dataset):
    def __init__(self, root_dir, label_dir):
        #root_dir是train或val文件夹的地址
        #label_dir是ants或bees
        self.root_dir = root_dir
        self.label_dir = label_dir
        # 将root_dir和label_dir拼接在一起得到图片的地址eg:train/ants
        self.path = os.path.join(self.root_dir,self.label_dir) 
        # 得到图片地址下所有图片的名字列表eg:['1.jpg','2.jpg','3.jpg']
        self.img_path = os.listdir(self.path) 

    def __getitem__(self,idx):
        #某张图片的名字eg:'1.jpg'
        img_name = self.img_path[idx] 
        # 某张图片的地址eg:train/ants/1.jpg
        img_item_path = os.path.join(self.root_dir,self_label_dir,img_name) 
        # label是文件名eg:ants
        label = self.label_dir
        return img_item_path,label # 返回第idx张图片的地址和label

    def __len__(self):
        return len(self.img_path) #返回总共有多少张图片

root_dir = 'hymenoptera_data/train' # 输入自己的数据地址

ants_label_dir = 'ants'
bees_label_dir = 'bees'

ants_dataset = MyData(root_dir, ants_label_dir) #创建一个实例
bees_dataset = MyData(root_dir, bees_label_dir) #创建一个实例

train_dataset = ants_dataset + bees_dataset # 拼接两个数据集=总的训练集

2. dataloader

2.1. 作用:

为后面的网络提供不同的数据形式(让网络能够一个batch一个batch的处理数据)

问题:

  • 如何使用os、PIL
  • 不同的数据集组织形式如何加载
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值