完整目录
关于 Dataset 数据加载
- 如何理解数据的加载呢?
- 就是说,要想获取自己电脑里的数据,读取它,那么就要遵守 pytorch 加载数据的规则。
他的规则就是 定义一个class类,继承 Dataset (from torch.utils.data import Dataset),并且,在类中,定义三个函数,分别是:初始化 init、获得每一个数据 getitem、数据长度 len。 - 这里面的过程,要很清楚:
1、路径、合并路径、把文件夹中的每一个文件名称,做成一个列表(这是init要做的事情);
2、访问init中的列表,把列表的名称逐一传递给一个变量,命名为name,再次合并路径,并且把文件名连接在路径之后,接下来,用PIL中的Image.open函数,读取(加载)上述路径的文件(命名为img)(这里肯定是图像了),返回 图像img和标签 label(这是getitem的工作);
3、最后用len()返回列表的长度。 - 定义好 类 以后,后面就可以实例化这个类,定义参数(本例其实是一个路径,一个夹名称了),名称可以和定义类中的不一样,但是位置要对应(奥,这可能是Python课程里说的位置参数?)。
- 引用之前定义的类,把上述参数,传递进去。
- 最后打印自定义数据列表的长度。
可运行的代码
import os
from PIL import Image
from torch.utils.data import Dataset
class MyData(Dataset):
def __init__(self, root_dir, label_dir):
self.root_dir = root_dir
self.label_dir = label_dir
self.path = os.path.join(self.root_dir,self.label_dir)
self.img_path = os.listdir(self.path)
def __getitem__(self, idx):
img_name = self.img_path[idx]
img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
img = Image.open(img_item_path)
label = self.label_dir
return img, label
def __len__(self):
return len(self.img_path)
if __name__ == '__main__':
root_dir = "F:\\PhD\\01-Python_In_One\\Project\\【B_up】XiaoTuDui\\data\\train"
ants_label_dir = "ants_image"
bees_label_dir = "bees_image"
ants_dataset = MyData(root_dir, ants_label_dir)
bees_dataset = MyData(root_dir, bees_label_dir)
train_dataset = ants_dataset + bees_dataset
完整代码 P6-7_read_data.py
"""
author :24nemo
date :2021年07月12日
"""
'''
Dataset:
能把数据进行编号
提供一种方式,获取数据,及其label,实现两个功能:
1、如何获取每一个数据,及其label
2、告诉我们总共有多少个数据
数据集的组织形式,有两种方式:
1、文件夹的名字,就是数据的label
2、文件名和label,分别处在两个文件夹中,label可以用txt的格式进行存储
在jupyter中,可以查看,help,两个方式:
1、help(Dataset)
2、Dataset??
Dataloader:
为网络提供不同的数据形式,比如将0、1、2、3进行打包
这一节内容很重要
'''
'''
# writer = SummaryWriter("logs")
class MyData(Dataset):
def __init__(self, root_dir, image_dir, label_dir, transform):
# 初始化,为这个函数用来设置在类中的全局变量
self.root_dir = root_dir
self.image_dir = image_dir
self.label_dir = label_dir
self.label_path = os.path.join(self.root_dir, self.label_dir)
self.image_path = os.path.join(self.root_dir, self.image_dir)
self.image_list = os.listdir(self.image_path)
self.label_list = os.listdir(self.label_path)
self.transform = transform
# 因为 label 和 Image文件名相同,进行一样的排序,可以保证取出的数据和label是一一对应的
self.image_list.sort()
self.label_list.sort()
def __getitem__(self, idx):
img_name = self.image_list[idx]
label_name = self.label_list[idx]
img_item_path = os.path.join(self.root_dir, self.image_dir, img_name)
label_item_path = os.path.join(self.root_dir, self.label_dir, label_name)
img = Image.open(img_item_path)
with open(label_item_path, 'r') as f:
label = f.readline()
# img = np.array(img)
img = self.transform(img)
sample = {'img': img, 'label': label}
return sample
def __len__(self):
# assert len(self.image_list) == len(self.label_list)
return len(self.image_list)
if __name__ == '__main__':
transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])
root_dir = "dataset/train"
image_ants = "ants_image"
label_ants = "ants_label"
ants_dataset = MyData(root_dir, image_ants, label_ants, transform)
image_bees = "bees_image"
label_bees = "bees_label"
bees_dataset = MyData(root_dir, image_bees, label_bees, transform)
train_dataset = ants_dataset + bees_dataset
# transforms = transforms.Compose([transforms.Resize(256, 256)])
dataloader = DataLoader(train_dataset, batch_size=1, num_workers=2)
# writer.add_image('error', train_dataset[119]['img'])
# writer.close()
# for i, j in enumerate(dataloader):
# # imgs, labels = j
# print(type(j))
# print(i, j['img'].shape)
# # writer.add_image("train_data_b2", make_grid(j['img']), i)
# writer.close()
# jupyter notebook 等方法,可以查看 help
'''
'''
以下内容是视频中完全一样的代码,截图,在 20210713 的笔记中,包括 python console 的代码也有保存
'''
运行结果
完整目录