dataloader类有四个参数:image_folder(图像所在文件夹的路径),image_list_file(记录图像和标签路径的列表文件),transform=None(可选的数据变换,默认不做变换),shuffle=True(是否打乱样本顺序)
五个函数:
__init__(初始化参数和函数方法),
read_list(从image_list_file中逐行读取data和label的路径,拼接image_folder后存入list[]中),
preprocess(检查data和label的高和宽是否一致,如果传入了transform就对二者做变换,并给label增加一个通道维度),
__len__(返回list[]的长度),
__call__(生成器:依次读出list[]中每一对data和label,打印其shape,经preprocess处理后逐个yield出来)
#读取我们要的图和label,然后做一个基础的处理,最后把他们返回过来
#class transform还没搞懂,transform和dataloader的参数transform有什么关系还不知道
import os
import random
import numpy as np
import cv2
import paddle.fluid as fluid
class Transform(object):
    """Resize an image and its label map to a fixed square size.

    The image is resized with bilinear interpolation. The label map is
    resized with nearest-neighbour interpolation so that every output
    value is one of the input label values (no blended class ids).
    """

    def __init__(self, size=256):
        """Store ``size``, the edge length of the square output."""
        self.size = size

    def __call__(self, input, label):
        """Return ``(input, label)``, each resized to ``(size, size)``.

        Args:
            input: image array accepted by ``cv2.resize``.
            label: label array accepted by ``cv2.resize``.

        Returns:
            Tuple ``(resized_input, resized_label)``.
        """
        # NOTE: `input` shadows the builtin, but it is part of the public
        # signature, so the name is kept for keyword-argument callers.
        target_size = (self.size, self.size)
        input = cv2.resize(input, target_size, interpolation=cv2.INTER_LINEAR)
        # Resize the label itself; the previous code resized `input` again,
        # which discarded the real label and returned a copy of the image.
        label = cv2.resize(label, target_size, interpolation=cv2.INTER_NEAREST)
        return input, label
class BasicDataLoader(object):
    """Sample generator yielding ``(data, label)`` pairs named in a list file.

    Each non-blank line of ``image_list_file`` holds two whitespace-separated
    paths, both joined onto ``image_folder``: first the image, then its
    label map. Calling the instance returns a generator over the samples.
    """

    def __init__(self,
                 image_folder,
                 image_list_file,
                 transform=None,
                 shuffle=True):
        """Store the configuration and build the list of sample paths.

        Args:
            image_folder: directory prefix joined onto every listed path.
            image_list_file: text file with one "image label" pair per line.
            transform: optional callable ``(data, label) -> (data, label)``.
            shuffle: when true, the sample list is shuffled once, here.
        """
        self.image_folder = image_folder
        self.image_list_file = image_list_file
        self.transform = transform
        self.shuffle = shuffle
        # List of (data_path, label_path) tuples.
        self.data_list = self.read_list()

    def read_list(self):
        """Parse the list file into ``(data_path, label_path)`` tuples.

        Blank lines are skipped. A non-blank line with fewer than two
        fields raises ``IndexError``. The result is shuffled only when
        ``self.shuffle`` is true.
        """
        data_list = []
        with open(self.image_list_file) as infile:
            for line in infile:
                fields = line.split()
                if not fields:
                    # A blank (e.g. trailing) line carries no sample.
                    continue
                data_path = os.path.join(self.image_folder, fields[0])
                label_path = os.path.join(self.image_folder, fields[1])
                data_list.append((data_path, label_path))
        # Honour the constructor flag; previously the list was always shuffled.
        if self.shuffle:
            random.shuffle(data_list)
        return data_list

    def preprocess(self, data, label):
        """Validate one sample, apply the transform, add a label channel axis.

        Args:
            data: 3-D array ``(h, w, c)``.
            label: 2-D array ``(h, w)`` with the same height and width.

        Returns:
            ``(data, label)`` where ``label`` has a trailing axis of length 1.
        """
        # The 3-way unpack also enforces that `data` is 3-dimensional.
        h, w, _ = data.shape
        h_gt, w_gt = label.shape
        assert h == h_gt, "Error"
        assert w == w_gt, "Error"
        if self.transform:
            data, label = self.transform(data, label)
        label = label[:, :, np.newaxis]
        return data, label

    def __len__(self):
        """Return the number of samples in the list."""
        return len(self.data_list)

    def __call__(self):
        """Yield ``(data, label)`` for every sample in list order.

        The image is read as colour and converted from BGR to RGB; the label
        is read as a single-channel grayscale image.

        Raises:
            OSError: if OpenCV cannot read an image or a label file.
        """
        for data_path, label_path in self.data_list:
            data = cv2.imread(data_path, cv2.IMREAD_COLOR)
            if data is None:
                # imread signals failure by returning None rather than raising.
                raise OSError(f"cannot read image file: {data_path}")
            data = cv2.cvtColor(data, cv2.COLOR_BGR2RGB)
            label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
            if label is None:
                raise OSError(f"cannot read label file: {label_path}")
            print(data.shape, label.shape)
            data, label = self.preprocess(data, label)
            yield data, label
def main():
    """Feed the dummy dataset through a fluid DataLoader and print batch shapes.

    Runs on the CPU in dygraph mode for a fixed number of epochs, batching
    the samples produced by ``BasicDataLoader``.
    """
    num_epoch = 5
    batch_size = 5
    place = fluid.CPUPlace()

    with fluid.dygraph.guard(place):
        # Python-side sample source: yields one (data, label) pair at a time.
        sample_source = BasicDataLoader(
            image_folder='work/dummy_data/',
            image_list_file='work/dummy_data/list.txt',
            transform=Transform(256),
            shuffle=True,
        )

        # The fluid loader groups the individual samples into batches.
        batch_loader = fluid.io.DataLoader.from_generator(
            capacity=1,
            use_multiprocess=False,
        )
        batch_loader.set_sample_generator(
            sample_source,
            batch_size=batch_size,
            places=place,
        )

        for epoch in range(1, num_epoch + 1):
            print(f'Epoch[{epoch}/{num_epoch}]:')
            # enumerate supplies the running batch index within the epoch.
            for idx, batch in enumerate(batch_loader):
                data, label = batch
                print(f'iter {idx}, Data shape: {data.shape}, Label shape:{label.shape}')


if __name__ == "__main__":
    main()