目的:上一小节将cifar10的图片存储为png格式后,又划分了训练集、验证集和测试集。这一小节的目的是为了让PyTorch能读取我们的数据集。
首先读取图片的路径、标签并将其保存到txt文件中
1_3_generate_txt
# coding:utf-8
import os
'''
为数据集生成对应的txt文件
'''
train_txt_path = os.path.join("..", "..", "Data", "train.txt")
train_dir = os.path.join("..", "..", "Data", "train")
valid_txt_path = os.path.join("..", "..", "Data", "valid.txt")
valid_dir = os.path.join("..", "..", "Data", "valid")
def gen_txt(txt_path, img_dir):
f = open(txt_path, 'w')
for root, s_dirs, _ in os.walk(img_dir, topdown=True): # 获取 train文件下各文件夹名称
for sub_dir in s_dirs:
i_dir = os.path.join(root, sub_dir) # 获取各类的文件夹 绝对路径
img_list = os.listdir(i_dir) # 获取类别文件夹下所有png图片的路径
for i in range(len(img_list)):
if not img_list[i].endswith('png'): # 若不是png文件,跳过
continue
label = img_list[i].split('_')[0]
img_path = os.path.join(i_dir, img_list[i])
line = img_path + ' ' + label + '\n'
f.write(line)
f.close()
if __name__ == '__main__':
gen_txt(train_txt_path, train_dir)
gen_txt(valid_txt_path, valid_dir)
①.os.listdir()
输入:目录路径
输出:该路径下的文件和文件夹列表
例如代码中当i_dir为\Data\train\0时,img_list返回的就是类别为0的所有图片名称
其次是构建自己的Dataset子类
1_3_mydataset
# coding: utf-8
from PIL import Image
from torch.utils.data import Dataset
class MyDataset(Dataset):
def __init__(self, txt_path, transform=None, target_transform=None):
fh = open(txt_path, 'r')
imgs = []
for line in fh:
line = line.rstrip()
words = line.split()
imgs.append((words[0], int(words[1])))
self.imgs = imgs # 最主要就是要生成这个list, 然后DataLoader中给index,通过getitem读取图片数据
self.transform = transform
self.target_transform = target_transform
def __getitem__(self, index):
fn, label = self.imgs[index]
img = Image.open(fn).convert('RGB') # 像素值 0~255,在transfrom.totensor会除以255,使像素值变成 0~1
if self.transform is not None:
img = self.transform(img) # 在这里做transform,转为tensor等等
return img, label
def __len__(self):
return len(self.imgs)
代码内容概述:
MyDataset类中包含三个函数。初始化函数是将txt中的图片信息存储到列表self.imgs中,该列表中的每个元素都包含一张图片的地址和该图片对应标签。__getitem__()函数是输入一个索引index,返回该索引对应的图片和标签。__len__()函数返回图片总数。
不懂的地方:
DataLoader类和MyDataset类的作用