Dataset:
提供一种方式获取数据及其label
如何获取每一个数据及其label
告诉我们总共有多少个数据
1.首先导入包
from torch.utils.data import Dataset
from PIL import Image
import os
2.写一个类 ,继承Dataset。
class Mydata(Dataset):
def __init__(self, root_dir, laber_dir):
self.root_dir = root_dir
self.laber_dir = laber_dir
self.data_path = os.path.join(root_dir, laber_dir)
self.data_list = os.listdir(self.data_path)
def __getitem__(self, index):
return os.path.join(self.root_dir, self.laber_dir, self.data_list[index])
def __len__(self):
return len(self.data_list)
①def __init__(self, root_dir, laber_dir): 在创建类的时候,实现初始化。
root_dir:根目录
laber_dir:目标目录
通过data_path=os.path.join(root_dir,laber_dir)实现路径组合连接,使windows和linux都可以使用。
os.listdir[data_path]:会生成该文件夹内所有图片的"集合"
②重写def __getitem__(self, index):
返回下标为index图片的详细说明“路径+文件名”
实例化:
①根据下标查找单个数据集
ants_data = Mydata("dataset/train", "ants")
ants_data .__getitem__(2)返回第三张图片的信息 ants_data[2]。返回第三张图片的信息。[] 运算符取值时,会调用它的方法__getitem__
② 数据集的拼接
ants_data = Mydata("dataset/train", "ants")
bees_data = Mydata("dataset/train", "bees")
new_data = ants_data + bees_data
new_data就是拼接后两个的数据集。
③图片打开并查看
img = Image.open(my[5]) img.show()