from PIL import Image
from torch.utils.data import Dataset
import os
class mydata(Dataset):
def __init__(self,root_dir,label_dir):
self.root_dir=root_dir
self.label_dir=label_dir
self.path=os.path.join(self.root_dir,self.label_dir)#获取文件夹下的目录的所有文件
self.img_path=os.listdir(self.path)
def __getitem__(self,idx):
img_name=self.img_path[idx]#每个单个文件
img_item_path=os.path.join(self.root_dir,self.label_dir,img_name)#每个单个文件的地址
label=self.label_dir
img_show=Image.open(self.img_item_path)
return label,img_show
root_dir='获取的文件夹'
label_dir='文件夹下的目录'
img=mydata(root_dir,label_dir)
img[0]
其中mydata为继承Dataset的类