from torch.utils.data import Dataset
from PIL import Image
import os
class MyData(Dataset):
def __init__(self,root_dir,label_dir):
self.root_dir = root_dir
self.label_dir = label_dir
self.path = os.path.join(self.root_dir,self.label_dir)
self.img_path = os.listdir(self.path)
def __getitem__(self, item):
img_name = self.img_path[item]
img_item_path = os.path.join(self.root_dir,self.label_dir,img_name)
img = Image.open(img_item_path)
label = self.label_dir
return img,label
def __len__(self):
return len(self.img_path)
root_dir = "dataset/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = MyData(root_dir,ants_label_dir)
bees_dataset = MyData(root_dir,bees_label_dir)
train_dataset = ants_dataset+bees_dataset
img,label = bees_dataset[0]
img.show()
代码分析 :
-
self
参数:- 这是Python类方法的标准第一个参数。
- 代表类的实例本身,允许在方法内部访问和修改实例的属性。
-
root_dir
参数:- 用于指定数据集的根目录路径。
- 允许用户在创建类实例时灵活指定数据所在的主目录。
-
label_dir
参数:- 用于指定标签(或类别)目录的名称。
- 使得类可以处理不同的标签子目录
-
self.path = os.path.join(self.root_dir, self.label_dir)
想象你在整理照片。
root_dir
是你的相册柜,label_dir
是具体的相册名。这行代码就像你在说"我要找的照片在相册柜的这本相册里"。os.path.join()
就是帮你正确地指明位置,不管你用的是Windows还是Mac。 -
self.img_path = os.listdir(self.path)
继续用整理照片的比喻。这行代码就像你打开那本相册,快速浏览了一遍,在脑子里记下了所有照片的名字。
os.listdir()
就是帮你"浏览"并列出所有文件名,self.img_path
就是你记下的这个"名单"。