pytorch读取数据主要涉及到两个类:dataset和dataload,以读取图片为例子
dataset主要包括三个类:
首先初始化图片位置,一般写在初始化函数中
def __init__(self, root_dir, image_dir, label_dir):
self.root_dir = root_dir
self.image_dir = image_dir
self.label_dir = label_dir
self.image_path = os.path.join(self.root_dir,self.image_dir)
self.label_path = os.path.join(self.root_dir,self.label_dir)
self.img_list = os.listdir(self.image_path)
self.label_list = os.listdir(self.label_path)
self.img_list.sort()
self.label_list.sort()
然后是根据图片位置去读取图片并返回读取的图片和标签:
def __getitem__(self, item):
img_name = self.img_list[item]
label_name = self.label_list[item]
img_name_path = os.path.join(self.root_dir,self.image_dir,img_name)
label_name_path = os.path.join(self.root_dir,self.label_dir,label_name)
img = Image.open(img_name_path)
with open(label_name_path,'r') as f:
label = f.readline()
return img,label
最后返回数据集的长度,主要用于后续dataloader载入网络:
def __len__(self):
assert len(self.img_list) == len(self.label_list)
return len(self.img_list)
自定义数据集到这就初步完成了后续还有利用transform进行数据增强以及dataloader载入数据到网络中进行训练