【Pytorch学习笔记】数据模块01——Dataset和Dateloader

Pytorch数据模块两大核心:Dataset和Dataloader,下面开始介绍基本功能:

Dataset:

将数据读取,并预处理好,再提供给dataloader

必须实现的两个函数如下:

getitem:

1,读取数据;2,数据预处理;3,返回数据(通常是输入和标签)

示例:

def __getitem__(self, index):    
*"""    输入标量index, 从硬盘中读取数据,并预处理,to Tensor    
:param index:    
:return:    """*    
	path_img, label = self.img_info[index]    
	img = Image.open(path_img).convert('L')    
	if self.transform is not None:        
		img = self.transform(img)    
	
	return img, label

这段代码是Dataset类中的__getitem__方法的实现,解释其主要功能:

  1. 输入参数index用于索引指定的数据样本
  2. 主要执行三个步骤:
  • 从self.img_info中获取图片路径和标签
  • 使用PIL库打开图片并转换为灰度图('L’模式)
  • 如果存在transform预处理,则对图片进行变换处理
  1. 最终返回处理后的图片数据和对应的标签

这个方法是Dataset类的核心方法之一,它使得数据集能够通过索引的方式获取单个样本,这对于后续使用DataLoader进行批量数据加载非常重要

len:

返回数据集大小

示例:

def __len__(self):    
	if len(self.img_info) == 0:        
		raise Exception("\ndata_dir:{} is a empty dir! Please checkout your path to images!".format(self.root_dir))  *# 代码具有友好的提示功能,便于debug*    
	return len(self.img_info)

使用一个函数来建立磁盘关系:_get_img_info

收集和处理样本的路径信息和标签信息,存储到一个list中,供getitem使用。

示例:

def _get_img_info(self):
    """
    实现数据集的读取,将硬盘中的数据路径,标签读取进来,存在一个list中
    path, label
    :return:
    """
    df = pd.read_csv(self.path_csv)
    df.drop(df[df["set-type"] != self.mode].index, inplace=True)  # 只要对应的数据
    df.reset_index(inplace=True)    # 非常重要! pandas的drop不会改变index
    # 遍历表格,获取每张样本信息
    for idx in range(len(df)):
        path_img = os.path.join(self.root_dir, df.loc[idx, "img-name"])
        label_int = int(df.loc[idx, "label"])
        self.img_info.append((path_img, label_int))

这段代码是_get_img_info方法的实现,它的主要功能是读取和组织数据集的信息。详细解释这段代码的工作流程:

  1. 读取CSV文件:使用pandas读取包含数据集信息的CSV文件
  2. 数据筛选:
  • 根据"set-type"字段筛选出符合指定mode的数据
  • 重置索引(这一步很重要,因为pandas的drop操作不会自动重置索引)
  1. 数据收集:遍历筛选后的数据框,对每个样本:
  • 构建完整的图片路径(结合root_dir和图片名)
  • 获取对应的标签
  • 将(图片路径, 标签)的元组添加到self.img_info列表中

这个方法与__getitem__的关系:

当__getitem__被调用时,它会使用index参数从self.img_info列表中获取对应的图片路径和标签。这样就建立了数据在硬盘中的位置和数据集索引之间的映射关系,使得数据集能够通过索引的方式获取单个样本。

Dataloader:

五大功能:

  1. 支持两种形式数据集读取
  2. 自定义采样策略
  3. 自动组装成批数据
  4. 多进程数据加载
  5. 自动实现锁页内存

DataLoader 常用API

  • dataset:一个Dataset的实例,要能实现索引到样本的映射。
  • batch_size:每个batch的样本量
  • shuffle:是否打乱样本顺序
  • sampler:设置采样策略
  • batch_sampler:设置采样策略
  • num_workers:设置多少个子进程进行数据加载
  • collate_fn:组装数据的规则,决定如何将一批数据组装起来
  • pin_memory:是否使用锁页内存
  • drop_last:每个epoch是否放弃最后一批不足batch_size大小的数据。(不存在舍弃数据的情况,因为shuffle会打乱样本顺序)

详解batch_size参数

batch_size是深度学习中一个重要的超参数,它定义了每次投入模型训练的样本数量。

batch_size的作用:

  • **内存管理:**较小的batch_size需要更少的内存,适合在内存受限的情况下使用
  • **训练速度:**较大的batch_size通常能提高GPU利用率,加快训练速度
  • **模型性能:**batch_size会影响模型的泛化能力和收敛性能

选择合适的batch_size:

  1. 考虑硬件限制:根据GPU显存大小选择合适的batch_size
  2. 训练稳定性:较小的batch_size通常能提供更好的泛化性能
  3. 训练效率:较大的batch_size能提高训练速度,但可能影响模型收敛

常见batch_size选择:

  • 对于一般的深度学习任务:32、64、128是常用值
  • 对于大规模数据集:可以选择256、512等更大的值
  • 对于小数据集:可以选择8、16等较小的值

transform数据处理示例解析:

normalize = transforms.Normalize([0.5], [0.5])
    transforms_train = transforms.Compose([
        transforms.Resize((4, 4)),
        transforms.ToTensor(),
        normalize
    ])

这段代码展示了图像预处理的常用变换操作:

  • transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]):对RGB三个通道分别进行标准化,第一个列表是均值,第二个是标准差。这些值来自ImageNet数据集的统计结果,使用它们可以让模型更容易收敛。
  • transforms.Resize((224, 224)):将输入图片调整为224x224的尺寸,这是许多预训练模型的标准输入大小。
  • transforms.ToTensor():将PIL图像或NumPy数组转换为张量,同时将像素值范围从[0, 255]缩放到[0.0, 1.0]。
  • transforms.Compose:将多个transform操作组合成一个transform序列,按照列表顺序依次执行。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

越轨

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值