Pytorch入门之——数据加载总结


使用torch将数据加载到数据加载器中可以分为以下三类:
1.标签在文件夹上,可以通过ImageFolder载入。
2.标签在图片名上,可以通过继承dataset类载入。
3.标签是CSV文件,可以通过合适的文件拆分函数,拆成第一种/第二中情况。
ImageFolder和Dataset都可以把数据转换为Dataset对象,而DataLoader的作用则是把Dataset对象转换为mini-batch数据并提供迭代器。

1. 标签在文件夹上

在这里插入图片描述
这种类型的数据可以调用ImageFolder来载入,步骤可以分为以下三步:
1. 定义train_root和test_root,存放训练数据和测试数据的路经。
2. 定义transform,确定要对数据进行什么样的预处理操作。
3. 用ImageFolder把数据加载到数据加载器,再用dataloader把数据加载到数据迭代器。
代码如下:

import torch
import torchvision
from torchvision import transforms

train_root ="C:\\Users\\X15\\Desktop\\abc_recog\\save_abc\\train"
test_root = "C:\\Users\\X15\\Desktop\\abc_recog\\save_abc\\test"

# 定义变换
transform = transforms.Compose([
    transforms.Grayscale(1),
    transforms.Resize(40),
    transforms.ToTensor(),  # 将输入图像转换为张量
    transforms.Normalize(mean=[0.5], std=[0.5])  # 对张量进行归一化
])

# 训练数据集
train_data = torchvision.datasets.ImageFolder(train_root, transform=transform)
train_iter = torch.utils.data.DataLoader(train_data, batch_size=20, shuffle=True, num_workers=0)

2. 标签在图片名上

在这里插入图片描述

加载这种类型的数据需要继承Dataset类,对Dataset类的继承可以分为三步:

  1. init:初始化
  2. len:返回整个数据集大小
  3. getitem:根据索引返回图像及标签
    代码如下:
import os
from torch.utils.data import DataLoader,Dataset
from torchvision import transforms
from PIL import Image

# 继承dataset
class Mydataset(Dataset):
    
    # step1 初始化
    def __init__(self, path_dir, transform=None): #变量视情况添加,至少有文件路径和tranform
        self.path_dir = path_dir
        self.transform = transform
        self.data = os.listdir(self.path_dir)
        
    # step2 len
    def __len__(self): # 一般不用更改
        return len(self.data)
    
    # step3 getitem
    def __getitem__(self, index): #重写的核心
        
        # 加载文件目录下的文件和文件名
        data_name = self.data[index] # 获取该索引的图像名,包含文件扩展名
        data_path = os.path.join(self.path_dir, data_name)
        image = Image.open(data_path).convert('RGB')
        
        # 处理文件名得到label
        label_raw = data_name.split('_')[0]  # 我的数名称是a_xx.jpg,因此从_分割
        label = ord(label_raw) - ord('a') # ord用于转换为对应ascll码,实现小写字母到0-25的映射
        
        # 对输入进行预处理操作
        if self.transform is not None:
            image = self.transform(image)
            
        return image, label # 训练数据集

# 使用

train_data = Mydataset(train_root, transform)

train_iter = torch.utils.data.DataLoader(train_data, batch_size=20, shuffle=True, num_workers=0)

这一类数据载入方法是最灵活、适用范围最广的,标签是另一个文件夹里的图像也可以使用重写dataset类的方法载入,getitem里的处理是这种方法的灵魂。

3. CSV文件存储标签

在这里插入图片描述
test:测试集
train:训练集
label:同时存有训练集标签和对应文件名的csv文件
label.csv
这一类数据集的处理方法是,通过数据拆分程序把数据集拆分成训练集和验证集,并把训练集里的数据分门别类存入名为对应label的文件夹里,剩下的步骤与1相同,使用imagefolder载入。
在这里插入图片描述
代码示例:

import math
import os
import shutil
from collections import Counter

# 会用到的变量
data_dir = "F:\\jupyter\\pytorch\\data\\dog-breed-identification"#数据集的根目录
label_file = 'labels.csv'#根目录中csv的文件名加后缀
train_dir = 'train'#根目录中的训练集文件夹的名字
test_dir = 'test'#根目录中的测试集文件夹的名字
input_dir = 'train_valid_test'#用于存放拆分数据集的文件夹的名字,可以不用先创建,会自动创建
batch_size = 4#送往训练的一批次中的数据集的个数
valid_ratio = 0.1#将训练集拆分为90%为训练集10%为验证集

# 拆分程序
def reorg_dog_data(data_dir, label_file, train_dir, test_dir, input_dir,
                   valid_ratio):
    # 读取训练数据标签,label.csv文件读取标签以及对应的文件名。
    with open(os.path.join(data_dir, label_file), 'r') as f:
        # 跳过文件头行(栏名称)。
        lines = f.readlines()[1:]
        tokens = [l.rstrip().split(',') for l in lines]
        idx_label = dict(((idx, label) for idx, label in tokens))
    labels = set(idx_label.values())

    num_train = len(os.listdir(os.path.join(data_dir, train_dir)))#获取训练集的数量便于数据集的分割
    # 训练集中数量最少一类的狗的数量。
    min_num_train_per_label = (
        Counter(idx_label.values()).most_common()[:-2:-1][0][1])
    # 验证集中每类狗的数量。
    num_valid_per_label = math.floor(min_num_train_per_label * valid_ratio)
    label_count = dict()

    def mkdir_if_not_exist(path):#判断是否有存放拆分后数据集的文件夹,没有就创建一个
        if not os.path.exists(os.path.join(*path)):
            os.makedirs(os.path.join(*path))

    # 整理训练和验证集,将数据集进行拆分复制到预先设置好的存放文件夹中。
    for train_file in os.listdir(os.path.join(data_dir, train_dir)):
        idx = train_file.split('.')[0]
        label = idx_label[idx]
        mkdir_if_not_exist([data_dir, input_dir, 'train_valid', label])
        shutil.copy(os.path.join(data_dir, train_dir, train_file),
                    os.path.join(data_dir, input_dir, 'train_valid', label))
        if label not in label_count or label_count[label] < num_valid_per_label:
            mkdir_if_not_exist([data_dir, input_dir, 'valid', label])
            shutil.copy(os.path.join(data_dir, train_dir, train_file),
                        os.path.join(data_dir, input_dir, 'valid', label))
            label_count[label] = label_count.get(label, 0) + 1
        else:
            mkdir_if_not_exist([data_dir, input_dir, 'train', label])
            shutil.copy(os.path.join(data_dir, train_dir, train_file),
                        os.path.join(data_dir, input_dir, 'train', label))

    # 整理测试集,将测试集复制存放在新建路径下的unknown文件夹中。
    mkdir_if_not_exist([data_dir, input_dir, 'test', 'unknown'])
    for test_file in os.listdir(os.path.join(data_dir, test_dir)):
        shutil.copy(os.path.join(data_dir, test_dir, test_file),
                    os.path.join(data_dir, input_dir, 'test', 'unknown'))

# 拆分
#载入数据,进行数据的拆分
reorg_dog_data(data_dir, label_file, train_dir, test_dir, input_dir,valid_ratio)

# 数据加载部分可参考第一节

[参考文档]:
GitHub - MRZHANG-1997/Python

  • 5
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值