像这样,图片分成了两个文件夹,一个是train,一个是val,如果数据集过大,可以自定义一个类读取数据集。
标签有两种存贮方式
# 1.文件名为标签 # 2.写有对应的txt文本里有每一个图片所对应的标签
如果文件名为标签,并且类别较多,我们可以用一个列表去存贮这个类别
import os label = [] root_dir = "E:\\pycharm_data\\read_Data1\\hymenoptera_data\\train" img_label = os.listdir(root_dir) for i in img_label: label.append(i) print(label)
# 读取数据集的包
from torch.utils.data import Dataset
from PIL import Image
import os
class MyData(Dataset):
def __init__(self,root_dir,label_dir):
#os.listdir()方法用于返回指定的文件夹包含的文件或文件夹的名字的列表。
#os.path.join()是将根目录和子目录用/加起来,
#会自动根据Linux或Windows系统不同做不同的拼接
self.root_dir = root_dir
self.label_dir = label_dir
self.path = os.path.join(self.root_dir, self.label_dir)
self.img_path = os.listdir(self.path)
# idx是指每个图片的序号,我们会把图片序号做成一个列表
# 这个函数是用来从文件目录根据idx索引值读取其中一个文件
def __getitem__(self, idx):
#从文件列表取中相应索引位置的图片名称(注意并不是图片本身,只是名字)
img_name = self.img_path[idx]
# img_item_path就是取索引值的图片的路径
img_item_path = os.path.join(self.path, img_name)
# img就是取索引值的图片
img = Image.open(img_item_path)
label = self.label_dir
return img, label
def __len__(self):
return len(self.img_path)
root_dir = "E:\\test_wangge\\read_Data1\\hymenoptera_data\\train"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = MyData(root_dir, ants_label_dir)
bees_dataset = MyData(root_dir,bees_label_dir)
# 如果在类中定义了__getitem__()方法,那么他的实例对象(假设为P)就可以这样P[key]取值。
# 当实例对象做P[key]运算时,就会调用类中的__getitem__()方法。
img, label = ants_dataset[0]
# 可以把两个数据集加起来
train_dataset = ants_dataset + bees_dataset
print(len(ants_dataset))
print(len(bees_dataset))
print(len(train_dataset))