图片数据集按如下形式存储:根目录(root_dir)下每个类别对应一个以数字标签命名的子目录(0、1、…、n-1),每个子目录中存放该类别的全部图片。
数据集转换:下面的模块先为所有图片生成"路径-标签"映射文件(trainmap.txt、valmap.txt、testmap.txt),再将其封装为 PyTorch 的 Dataset;下一篇将介绍在 Ray 框架下训练 PyTorch 模型时如何调用该模块。
import os import cv2 import numpy as np import torch from torch.utils.data import Dataset import matplotlib.image as mpimg # 对所有图片生成path-label map.txt 这个程序可根据实际需要适当修改 def generate_map(root_dir, n): # root_dir为D:/tmp/photo # 得到当前绝对路径 current_path = os.path.abspath('data') # os.path.dirname()向前退一个路径 father_path = os.path.abspath(os.path.dirname(current_path) + os.path.sep + ".") for idx in range(n): subdir = os.path.join(root_dir, '%d/' % idx) all_name = [] for file_name in os.listdir(subdir): all_name.append(file_name) len_all_name = len(all_name) # 划分训练验证测试集 split_1 = int(len_all_name * 0.6) split_2 = int(len_all_name * 0.8) train_name = all_name[:split_1] val_name = all_name[split_1:split_2] test_name = all_name[split_2:] # 将训练、验证、测试集的路径和标签写入不同的txt文件 with open(os.path.join(root_dir, 'trainmap.txt'), 'w') as wfp1: for i in range(len(train_name)): abs_name = os.path.join(father_path, subdir, train_name[i]) # linux_abs_name = abs_name.replace("\\", '/') wfp1.write('{file_dir} {label}\n'.format(file_dir=abs_name, label=idx)) with open(os.path.join(root_dir, 'valmap.txt'), 'w') as wfp2: for i in range(len(val_name)): abs_name = os.path.join(father_path, subdir, val_name[i]) # linux_abs_name = abs_name.replace("\\", '/') wfp2.write('{file_dir} {label}\n'.format(file_dir=abs_name, label=idx)) with open(os.path.join(root_dir, 'testmap.txt'), 'w') as wfp3: for i in range(len(test_name)): abs_name = os.path.join(father_path, subdir, test_name[i]) # linux_abs_name = abs_name.replace("\\", '/') wfp3.write('{file_dir} {label}\n'.format(file_dir=abs_name, label=idx)) # 实现MyDatasets类 class MyDatasets(Dataset): def __init__(self, dir, method): # 获取数据存放的dir # 例如d:/images/ self.data_dir = dir # 用于存放(image,label) tuple的list,存放的数据例如(d:/image/1.png,4) self.image_target_list = [] # 从dir--label的map文件中将所有的tuple对读取到image_target_list中 # map.txt中全部存放的是d:/.../image_data/1/3.jpg 1 路径最好是绝对路径 with open(os.path.join(dir, method), 'r') as fp: content = fp.readlines() # 
s.rstrip()删除字符串末尾指定字符(默认是字符) # 得到 [['d:/.../image_data/1/3.jpg', '1'], ...,] str_list = [s.rstrip().split() for s in content] # 将所有图片的dir--label对都放入列表,如果要执行多个epoch,可以在这里多复制几遍,然后统一shuffle比较好 self.image_target_list = [(x[0], int(x[1])) for x in str_list] def __getitem__(self, index): image_label_pair = self.image_target_list[index] # 按path读取图片数据,并转换为图片格式例如[3,32,32] # 可以用别的代替 img = mpimg.imread(image_label_pair[0]) img = np.resize(img, (3, 32, 32)) img = torch.from_numpy(img).float() return img, image_label_pair[1] def __len__(self): return len(self.image_target_list)