Pytorch入门之——数据加载总结
使用torch将数据加载到数据加载器中可以分为以下三类:
1.标签在文件夹上,可以通过ImageFolder载入。
2.标签在图片名上,可以通过继承dataset类载入。
3.标签是CSV文件,可以通过合适的文件拆分函数,拆成第一种/第二中情况。
ImageFolder和Dataset都可以把数据转换为Dataset对象,而DataLoader的作用则是把Dataset对象转换为mini-batch数据并提供迭代器。
1. 标签在文件夹上
这种类型的数据可以调用ImageFolder来载入,步骤可以分为以下三步:
1. 定义train_root和test_root,存放训练数据和测试数据的路经。
2. 定义transform,确定要对数据进行什么样的预处理操作。
3. 用ImageFolder把数据加载到数据加载器,再用dataloader把数据加载到数据迭代器。
代码如下:
import torch
import torchvision
from torchvision import transforms
train_root ="C:\\Users\\X15\\Desktop\\abc_recog\\save_abc\\train"
test_root = "C:\\Users\\X15\\Desktop\\abc_recog\\save_abc\\test"
# 定义变换
transform = transforms.Compose([
transforms.Grayscale(1),
transforms.Resize(40),
transforms.ToTensor(), # 将输入图像转换为张量
transforms.Normalize(mean=[0.5], std=[0.5]) # 对张量进行归一化
])
# 训练数据集
train_data = torchvision.datasets.ImageFolder(train_root, transform=transform)
train_iter = torch.utils.data.DataLoader(train_data, batch_size=20, shuffle=True, num_workers=0)
2. 标签在图片名上
加载这种类型的数据需要继承Dataset类,对Dataset类的继承可以分为三步:
- init:初始化
- len:返回整个数据集大小
- getitem:根据索引返回图像及标签
代码如下:
import os
from torch.utils.data import DataLoader,Dataset
from torchvision import transforms
from PIL import Image
# 继承dataset
class Mydataset(Dataset):
# step1 初始化
def __init__(self, path_dir, transform=None): #变量视情况添加,至少有文件路径和tranform
self.path_dir = path_dir
self.transform = transform
self.data = os.listdir(self.path_dir)
# step2 len
def __len__(self): # 一般不用更改
return len(self.data)
# step3 getitem
def __getitem__(self, index): #重写的核心
# 加载文件目录下的文件和文件名
data_name = self.data[index] # 获取该索引的图像名,包含文件扩展名
data_path = os.path.join(self.path_dir, data_name)
image = Image.open(data_path).convert('RGB')
# 处理文件名得到label
label_raw = data_name.split('_')[0] # 我的数名称是a_xx.jpg,因此从_分割
label = ord(label_raw) - ord('a') # ord用于转换为对应ascll码,实现小写字母到0-25的映射
# 对输入进行预处理操作
if self.transform is not None:
image = self.transform(image)
return image, label # 训练数据集
# 使用
train_data = Mydataset(train_root, transform)
train_iter = torch.utils.data.DataLoader(train_data, batch_size=20, shuffle=True, num_workers=0)
这一类数据载入方法是最灵活、适用范围最广的,标签是另一个文件夹里的图像也可以使用重写dataset类的方法载入,getitem里的处理是这种方法的灵魂。
3. CSV文件存储标签
test:测试集
train:训练集
label:同时存有训练集标签和对应文件名的csv文件
这一类数据集的处理方法是,通过数据拆分程序把数据集拆分成训练集和验证集,并把训练集里的数据分门别类存入名为对应label的文件夹里,剩下的步骤与1相同,使用imagefolder载入。
代码示例:
import math
import os
import shutil
from collections import Counter
# 会用到的变量
data_dir = "F:\\jupyter\\pytorch\\data\\dog-breed-identification"#数据集的根目录
label_file = 'labels.csv'#根目录中csv的文件名加后缀
train_dir = 'train'#根目录中的训练集文件夹的名字
test_dir = 'test'#根目录中的测试集文件夹的名字
input_dir = 'train_valid_test'#用于存放拆分数据集的文件夹的名字,可以不用先创建,会自动创建
batch_size = 4#送往训练的一批次中的数据集的个数
valid_ratio = 0.1#将训练集拆分为90%为训练集10%为验证集
# 拆分程序
def reorg_dog_data(data_dir, label_file, train_dir, test_dir, input_dir,
valid_ratio):
# 读取训练数据标签,label.csv文件读取标签以及对应的文件名。
with open(os.path.join(data_dir, label_file), 'r') as f:
# 跳过文件头行(栏名称)。
lines = f.readlines()[1:]
tokens = [l.rstrip().split(',') for l in lines]
idx_label = dict(((idx, label) for idx, label in tokens))
labels = set(idx_label.values())
num_train = len(os.listdir(os.path.join(data_dir, train_dir)))#获取训练集的数量便于数据集的分割
# 训练集中数量最少一类的狗的数量。
min_num_train_per_label = (
Counter(idx_label.values()).most_common()[:-2:-1][0][1])
# 验证集中每类狗的数量。
num_valid_per_label = math.floor(min_num_train_per_label * valid_ratio)
label_count = dict()
def mkdir_if_not_exist(path):#判断是否有存放拆分后数据集的文件夹,没有就创建一个
if not os.path.exists(os.path.join(*path)):
os.makedirs(os.path.join(*path))
# 整理训练和验证集,将数据集进行拆分复制到预先设置好的存放文件夹中。
for train_file in os.listdir(os.path.join(data_dir, train_dir)):
idx = train_file.split('.')[0]
label = idx_label[idx]
mkdir_if_not_exist([data_dir, input_dir, 'train_valid', label])
shutil.copy(os.path.join(data_dir, train_dir, train_file),
os.path.join(data_dir, input_dir, 'train_valid', label))
if label not in label_count or label_count[label] < num_valid_per_label:
mkdir_if_not_exist([data_dir, input_dir, 'valid', label])
shutil.copy(os.path.join(data_dir, train_dir, train_file),
os.path.join(data_dir, input_dir, 'valid', label))
label_count[label] = label_count.get(label, 0) + 1
else:
mkdir_if_not_exist([data_dir, input_dir, 'train', label])
shutil.copy(os.path.join(data_dir, train_dir, train_file),
os.path.join(data_dir, input_dir, 'train', label))
# 整理测试集,将测试集复制存放在新建路径下的unknown文件夹中。
mkdir_if_not_exist([data_dir, input_dir, 'test', 'unknown'])
for test_file in os.listdir(os.path.join(data_dir, test_dir)):
shutil.copy(os.path.join(data_dir, test_dir, test_file),
os.path.join(data_dir, input_dir, 'test', 'unknown'))
# 拆分
#载入数据,进行数据的拆分
reorg_dog_data(data_dir, label_file, train_dir, test_dir, input_dir,valid_ratio)
# 数据加载部分可参考第一节
[参考文档]:
GitHub - MRZHANG-1997/Python