pytorch: 加载数据集模板

import torch
import numpy as np
from torch.utils.data import DataLoader


class Getloader(torch.utils.data.Dataset):

    def __init__(self,data_root,data_label):
        self.data = data_root
        self.label = data_label

    def __getitem__(self, index):
        data = self.data[index]
        labels = self.label[index]
        return  data , labels

    def __len__(self):
        return len(self.data)

这部分是读取数据集的模板,可以把他当成一个容器,这个要配合DataLoader使用发挥作用。

实例化:

source_data = '图片路径'
source_label= ’图片路径‘
data = Getloader(source_data,source_label)

DataLoader:

函数介绍:

torch.utils.data.DataLoader(dataset,batch_size,shuffle,drop_last,num_workers)
  1. dataset: 加载torch.utils.data.Dataset对象数据
  2. batch_size: 每个batch的大小
  3. shuffle:是否对数据进行打乱
  4. drop_last:是否对无法整除的最后一个datasize进行丢弃
  5. num_workers:表示加载的时候子进程数

实例化:

datas = DataLoader(data,batch_size=5,shuffle=True,drop_last=False)

查看实例化其结果:

for i, data in enumerate(datas):
	# i表示第几个batch, data表示该batch对应的数据,包含data和对应的labels
    print("第 {} 个Batch \n{}".format(i, data))

#这个是训练网络时调用的方法
for data , label in datas:
    print(data)
    print(label) #输出的是张量数组

这个DataLoader会返回Dataset里面的_getitem_中的return值,并且每次返回batch_size=5个值(这里的data),他会把这5个值放到一个序列里(列表、元组),这个5个data和5个label很好理解是下标对应的关系(故可以作为同一对训练集使用)。

当我们训练的数据集为图片时,dataset中_getitem_返回的要是张量,这时设置batch_size=1,则读取的就是单个的图片和对应的标签图片了。

注:用dataloader得到dataset里的图片自动由numpy转化为tensor,故这里不需要手动转化,而在预测的时候,你要手动转化一下

完整测试代码贴上去,大家可以试一下

import torch
from torch.utils.data import DataLoader
class Getloader(torch.utils.data.Dataset):

    def __init__(self,data_root,data_label):
        self.data = data_root
        self.label = data_label

    def __getitem__(self, index):
        data = self.data[index]
        labels = self.label[index]
        return  data ,labels

    def __len__(self):
        return len(self.data)

source_data = 'abcdefgh'
source_label= [1,2,3,4,5,6,7,8]
data = Getloader(source_data,source_label)
datas = DataLoader(data,batch_size=3,shuffle=True,drop_last=False)

for i ,da in datas:
        print('-----------')
        print(i)
        print(da)

 这是结果:

D:\Anaconda\envs\pytorch3.7\python.exe D:/project/pytorch/try/main.py
-----------
('g', 'h', 'c')
tensor([7, 8, 3])
-----------
('d', 'f', 'a')
tensor([4, 6, 1])
-----------
('b', 'e')
tensor([2, 5])

进程已结束,退出代码为 0

再贴一个可以使用的例子:

import torch
import cv2
import os
import glob
from torch.utils.data import Dataset


class ISBI_Loader(Dataset):
    def __init__(self, data_path):
        # 初始化函数,读取所有data_path下的图片
        self.data_path = data_path
        self.imgs_path = glob.glob(os.path.join(data_path, 'image2/T*.png'))    #得到一个list(遍历路径下的所有)
        self.labs_path = glob.glob(os.path.join(data_path, 'label2/P*.jpg'))

    def __getitem__(self, index):
        # 根据index读取图片
        image_path = self.imgs_path[index]
        # 根据image_path生成label_path
        label_path = self.labs_path[index]
        # 读取训练图片和标签图片

        image = cv2.imread(image_path)
        label = cv2.imread(label_path)


        return image, label

    def __len__(self):
        # 返回训练集大小
        return len(self.imgs_path)

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值