Dataset demo
-
- 前置基础知识
os库的使用
import os dir_path = "dataset/train/ants" img_path_list = os.listdir(dir_path) #此时img_path_list为文件列表 img_path_list[0] >>> 'xxx.jpg'
- 前置基础知识
-
- 本文使用的数据集样式
- 本文使用的数据集样式
其中 i m a g e image image文件夹中为jpg格式图片, l a b e l label label文件夹中为txt文本。
一般设置 r o o t _ d i r root\_dir root_dir为所有数据集的根目录, 设置 t r a i n i m g train_img trainimg为数据相对根目录的地址,设置 l a b e l _ d i r label\_dir label_dir为标签相对根目录的地址(名字随意,自己能区分就好)。
使用 o s . p a t h . j o i n ( r o o t _ d i r , l a b e l _ d i r ) os.path.join(root\_dir, label\_dir) os.path.join(root_dir,label_dir)函数将他们拼接起来。
-
- demo
from torch.utils.data import Dataset
import os
from PIL import Image
class MyData(Dataset):
# 数据类初始化
def __init__(self, root_dir, train_img_dir_name, label_dir):
# 定义数据主文件夹地址
self.root_dir = root_dir
# 定义标签完整地址(根据不同数据做调整)
self.label_dir = os.path.join(self.root_dir, label_dir)
# 定义图片文件夹路径名
self.train_img_dir_name = train_img_dir_name
# 地址拼接
self.path = os.path.join(self.root_dir, self.train_img_dir_name)
# 数据列表
self.img_path = os.listdir(self.path)
self.img_label = os.listdir(self.label_dir)
# 得到单个数据(图片)
def __getitem__(self, idx):
# 从列表中得到单一数据的name
img_name = self.img_path[idx]
# 获取标签
img_label = self.img_label[idx]
# 得到数据地址
img_item_path = os.path.join(self.root_dir, self.train_img_dir_name, img_name)
# 使用PIL中的Image库打开图片
img = Image.open(img_item_path)
# 获取标签,从对应的txt文件中读取标签名
img_item_label = os.path.join(self.label_dir, img_label)
self.label = ""
with open(img_item_label, "r") as f:
self.label = f.read()
return img, self.label
# 得到数据大小
def __len__(self):
return len(self.img_path)
# 创建一个实例
root_dir = 'dataset/train/'
ants_train_img = 'ants_image'
ants_label_dir = "ants_label"
ants_dataset = MyData(root_dir, ants_train_img, ants_label_dir)
# 创建第二个实例
bees_train_img = 'bees_image'
bees_label_dir = 'bees_label'
bees_dataset = MyData(root_dir, bees_train_img, bees_label_dir)
# 测试是否得到蚂蚁图片
# img, label = ants_dataset[0]
# img.show()
# print(label)
# print(len(ants_dataset))
# 测试是否得到蜜蜂图片
# img, label = bees_dataset[0]
# img.show()
# print(label)
# print(len(bees_dataset))
# 定义完整训练数据集
train_dataset = ants_dataset + bees_dataset
print(len(train_dataset))
# 测试
img, label = train_dataset[124]
img.show()
print(label)
print(len(train_dataset))