写一个系列代码实战,争取每天都更。
倒逼自己赶紧提升写 Python 代码的手感。
一、代码
import os
from PIL import Image
from torch.utils.data import Dataset
class Nemo(Dataset):
def __init__(self, root_dir, label_dir):
self.root_dir = root_dir
self.label_dir = label_dir
self.path = os.path.join(self.root_dir, self.label_dir)
self.img_path_list = os.listdir(self.path)
def __getitem__(self, idx):
image_name = self.img_path_list[idx]
image_item_path = os.path.join(self.root_dir, self.label_dir, image_name)
img = Image.open(image_item_path)
label = self.label_dir
return img, label
def __len__(self):
return len(self.img_path_list)
root_dir = "D:\\Python_In_One\\Project\\XiaoTuDui\\data\\train"
ants_label_dir = "ants_label"
ants_dataset = Nemo(root_dir, ants_label_dir)
print(len(ants_dataset))
二、理解
如何理解数据的加载呢?
就是说,要想获取自己电脑里的数据,读取它,那么就要遵守 pytorch 加载数据的规则。
他的规则就是 定义一个 类,继承 Dataset (from torch.utils.data import Dataset),并且,在类中,定义三个函数,分别是 初始化 init、获得每一个数据 getitem、数据长度 len。
这里面的过程,要很清楚:
1、 路径、合并路径、把文件夹中的每一个文件名称,做成一个列表(这是init要做的事情);
2、 访问init中的列表,把列表的名称逐一传递给一个变量,命名为name,再次合并路径,并且把文件名连接在路径之后,接下来,用PIL中的Image.open函数,读取(加载)上述路径的文件(命名为img)(这里肯定是图像了),返回 图像img和标签 label(这是getitem的工作);
3、 最后用len()返回列表的长度。
定义好 类 以后,后面就可以实例化这个类,定义参数(本例其实是一个路径,一个夹名称了),名称可以和定义类中的不一样,但是位置要对应(奥,这可能是Python课程里说的位置参数?)。
引用之前定义的类,把上述参数,传递进去。
最后打印自定义数据列表的长度。
参考内容
该案例是 上手学习 PyTorch 时,B站 up 【我是土堆】的代码实战。
提醒自己
在我的文件夹中,
文件名为:P6_7_read_data.py