记录一下pytorch读取大型数据集的要点
# pytorch 读取大数据集的一般方法
class mydataset(Data.Dataset):
def __init__(self,root='filepath'):
self.root = root
# __init__ 中读取文件路径而非文件本体
self.imgs_list = ...
self.labs_list = ...
def __getitem__(self,index):
img_path,lab = self.imgs_list[index],self.labs_list[index]
# __getitem__中读取文件,随取随用,避免内存占用过大
img_data = readimg(img_path)
if self.is_transform:
imgdata = torchvision.transforms.ToTensor
# torchvision 进行数据转换,速度快
return imgdata lab
def __len__(self):
return len(self.imgs_list)