P6-P7 数据加载+Dataset类代码实战
P6 数据加载
Dataset 和Dataloader区别
P7 Dataset类代码实战
写在前面:
a.安装opencv-python:
pip install opencv-python
import cv2
b.用pil或安装pillow (本篇用法)
from PIL import Image
MyData继承Dataset,需要定义三个函数
1、def init(self)
初始化,为这个函数用来设置在类中的全局变量
任务:获取路径、合并路径、地址列表化
class MyData(Dataset):
# 为后面的函数提供一些量
def __init__(self, root_dir, label_dir):
self.root_dir = root_dir
self.label_dir = label_dir
# 获取地址
self.path = os.path.join(self.root_dir, self.label_dir)
# 地址列表化
self.img_path = os.listdir(self.path) # 获取地址后,将文件夹中所有数据以列表形式存入
self:相当于一个船,函数的变量相当于一个人,只有变量放船上才能给别的函数用
2、 def getitem(self, idx)
获取各个图
工作:访问init中的列表,把列表的名称逐一传递给一个变量,再次合并路径,并且把文件名连接在路径,从而获取行对路径。之后用PIL中的Image.open函数读取(加载)上述路径的文件,返回 图像img和标签 label
# 获取各个图
def __getitem__(self, idx):
img_name = self.img_path[idx]
# 获取相对路径
img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
img = Image.open(img_item_path) # ps:操作图片前都需要打开open一下
label = self.label_dir # 图片上一级的目录
return img, label # 图的名称 和 图片标签(当前类中,标签为文件夹名称)
img:图的名称,由此可以使用 查看、print、size等功能
label:图片标签,当前类中,标签为文件夹名称(就这么定义的)
3、def len(self)
计算数据长度
def __len__(self):
return len(self.img_path)
root_dir = "hymenoptera_data/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
# 创建实例
ants_dataset = MyData(root_dir, ants_label_dir)
bees_dataset = MyData(root_dir, bees_label_dir)
train_dataset = ants_dataset + bees_dataset
完整代码
from torch.utils.data import Dataset
import os
from PIL import Image
class MyData(Dataset): # 继承Dataset
# 为后面的函数提供一些量
def __init__(self, root_dir, label_dir):
self.root_dir = root_dir
# self:相当于一个船,函数的变量相当于一个人,只有变量放船上才能给别的函数用
self.label_dir = label_dir
# 获取地址
self.path = os.path.join(self.root_dir, self.label_dir)
# 地址列表化
self.img_path = os.listdir(self.path) # 获取地址后,将文件夹中所有数据以列表形式存入
# 获取各个图
def __getitem__(self, idx):
img_name = self.img_path[idx]
# 获取相对路径
img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
img = Image.open(img_item_path) # ps:操作图片前都需要打开open一下
label = self.label_dir # 图片上一级的目录
return img, label # 图的名称 和 图片标签(当前类中,标签为文件夹名称)
#
def __len__(self):
return len(self.img_path)
root_dir = "hymenoptera_data/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
# 创建实例
ants_dataset = MyData(root_dir, ants_label_dir)
bees_dataset = MyData(root_dir, bees_label_dir)
train_dataset = ants_dataset + bees_dataset