百度飞桨(PaddlePaddle)数据读取类:class BasicDataLoader

BasicDataLoader类有四个参数:image_folder(数据所在文件夹的路径),image_list_file(列表文件的路径,每行依次写有图像和标签的相对路径),transform=None(可选的预处理变换对象),shuffle=True(是否打乱样本顺序)

五个函数:

__init__(初始化参数和函数方法),

read_list(逐行读取列表文件,把每行的image路径和label路径与image_folder拼接后成对存入list[]中),

preprocess(先判断data和label的高、宽是否一致;如果传入了transform就对二者做变换;最后给label增加一个通道维),

__len__(读取list[]的长度),

__call__(生成器:依次读出list[]中的每一对data和label,打印其shape,经过preprocess后逐个yield出data和label)


#读取我们要的图和label,然后做一个基础的处理,最后把他们返回过来
#Transform类与dataloader的transform参数的关系:把Transform的实例作为transform参数传给BasicDataLoader,preprocess中会调用它,对data和label做同样尺寸的缩放
import os
import random
import numpy as np
import cv2
import paddle.fluid as fluid

class Transform(object):
    """Resize an image and its label map to the same fixed square size.

    Written as a callable class so that further transforms can be added
    later without changing how callers invoke it.

    Args:
        size: Edge length, in pixels, of the square output. Defaults to 256.
    """

    def __init__(self, size=256):
        self.size = size

    def __call__(self, input, label):
        """Return ``(input, label)``, each resized to ``(size, size)``.

        Args:
            input: Image array accepted by ``cv2.resize``.
            label: Label array accepted by ``cv2.resize``.

        Returns:
            A tuple ``(resized_input, resized_label)``.
        """
        target_size = (self.size, self.size)
        input = cv2.resize(input, target_size, interpolation=cv2.INTER_LINEAR)
        # The label must be resized from ``label`` itself, not from the image.
        # Nearest-neighbour is used instead of interpolation so that no new,
        # out-of-range label values are produced.
        label = cv2.resize(label, target_size, interpolation=cv2.INTER_NEAREST)
        return input, label


class BasicDataLoader(object):
    """Sample generator that yields (image, label) pairs named in a list file.

    Every non-blank line of ``image_list_file`` holds two whitespace-separated
    paths, both relative to ``image_folder``: the image first, then its label.

    Args:
        image_folder: Directory that the listed paths are joined onto.
        image_list_file: Text file with one ``<image> <label>`` pair per line.
        transform: Optional callable ``transform(data, label)`` returning the
            transformed pair; not applied when falsy.
        shuffle: When true, the sample list is shuffled once, at load time.
    """

    def __init__(self,
                 image_folder,
                 image_list_file,
                 transform=None,
                 shuffle=True):
        # Keep the constructor arguments as attributes.
        self.image_folder = image_folder
        self.image_list_file = image_list_file
        self.transform = transform
        self.shuffle = shuffle
        # List of (image_path, label_path) tuples for the whole dataset.
        self.data_list = self.read_list()

    def read_list(self):
        """Parse the list file into a list of ``(image_path, label_path)``.

        Blank lines are skipped. The list is shuffled only when
        ``self.shuffle`` is true.

        Raises:
            ValueError: If a non-blank line has fewer than two fields.
        """
        data_list = []
        with open(self.image_list_file) as infile:
            for line_no, line in enumerate(infile, start=1):
                fields = line.split()
                if not fields:
                    # Tolerate empty lines, e.g. a trailing newline.
                    continue
                if len(fields) < 2:
                    raise ValueError(
                        f"{self.image_list_file}:{line_no}: expected "
                        f"'<image> <label>', got {line.strip()!r}")
                data_path = os.path.join(self.image_folder, fields[0])
                label_path = os.path.join(self.image_folder, fields[1])
                data_list.append((data_path, label_path))
        if self.shuffle:
            random.shuffle(data_list)
        return data_list

    def preprocess(self, data, label):
        """Check that the pair matches in size, transform it, add a label axis.

        Args:
            data: Array whose shape unpacks as ``(h, w, c)``.
            label: Array whose shape unpacks as ``(h, w)``.

        Returns:
            ``(data, label)`` where ``label`` has gained a trailing axis of
            length 1 and both have passed through ``self.transform`` if set.

        Raises:
            AssertionError: If the height or width of the two arrays differ.
        """
        h, w, c = data.shape
        h_gt, w_gt = label.shape
        assert h == h_gt, "Error"
        assert w == w_gt, "Error"
        if self.transform:
            data, label = self.transform(data, label)
        label = label[:, :, np.newaxis]
        return data, label

    def __len__(self):
        """Return the number of (image, label) pairs in the list."""
        return len(self.data_list)

    def __call__(self):
        """Yield preprocessed ``(data, label)`` pairs, one sample at a time.

        The image is read in colour and converted from BGR to RGB; the label
        is read as a single-channel grayscale image. The raw shapes of each
        pair are printed before preprocessing.

        Raises:
            OSError: If OpenCV cannot read an image or label file.
        """
        for data_path, label_path in self.data_list:
            data = cv2.imread(data_path, cv2.IMREAD_COLOR)
            if data is None:
                # cv2.imread signals failure by returning None, not raising.
                raise OSError(f"cannot read image file: {data_path}")
            data = cv2.cvtColor(data, cv2.COLOR_BGR2RGB)
            label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
            if label is None:
                raise OSError(f"cannot read label file: {label_path}")
            print(data.shape, label.shape)
            data, label = self.preprocess(data, label)
            yield data, label



def main():
    """Feed the dummy dataset through a fluid DataLoader and print batch shapes.

    Runs on CPU in dygraph mode, batches five samples at a time and walks the
    loader for five epochs, printing the shape of every batch.
    """
    num_epoch = 5
    batch_size = 5
    place = fluid.CPUPlace()

    with fluid.dygraph.guard(place):
        # Python-side generator that produces one (data, label) sample per step.
        basic_dataloader = BasicDataLoader(
            image_folder='work/dummy_data/',
            image_list_file='work/dummy_data/list.txt',
            transform=Transform(256),
            shuffle=True,
        )

        # fluid-side loader; the Python generator above is attached as its
        # sample source and grouped into batches of ``batch_size``.
        dataloader = fluid.io.DataLoader.from_generator(
            capacity=1,
            use_multiprocess=False,
        )
        dataloader.set_sample_generator(
            basic_dataloader,
            batch_size=batch_size,
            places=place,
        )

        for epoch in range(1, num_epoch + 1):
            print(f'Epoch[{epoch}/{num_epoch}]:')
            # ``idx`` is the running batch index within the current epoch.
            idx = 0
            for data, label in dataloader:
                print(f'iter {idx}, Data shape: {data.shape}, Label shape:{label.shape}')
                idx += 1


if __name__ == "__main__":
    main()

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值