pytorch学习-数据处理生成batch

  1. 加载数据并生成batch数据
  2. 数据预处理
  3. 构建神经网络
  4. Tensor和Variable
  5. 定义loss
  6. 自动求导
  7. 优化器更新参数
  8. 训练神经网络
  9. 参数_定义
  10. 参数_初始化
  11. 如何在训练时固定一些层?
  12. 绘制loss和accuracy曲线
  13. torch.nn.Container和torch.nn.Module
  14. 各层参数及激活值的可视化
  15. 保存训练好的模型
  16. 如何加载预训练模型
  17. 如何使用cuda进行训练

读取数据生成并构建Dataset子类

假设现在已经实现从数据文件中读取输入images和标记labels(列表),那么怎么根据images和labels定义自己的数据集类?答案是作为torch.utils.data.Dataset的子类。

torchvision.datasets中有几个已经定义好的数据集类,这些类都是torch.utils.data.Dataset抽象类的子类:

在定义torch.utils.data.Dataset的子类时,必须重载的两个函数是__len__和__getitem__。__len__返回数据集的大小,__getitem__实现数据集的下标索引,返回对应的图像和标记(不一定非得返回图像和标记,返回元组的长度可以是任意长,这由网络需要的数据决定)。
在创建DataLoader时会判断__getitem__返回值的数据类型,然后用不同的if/else分支把数据转换成tensor,所以,_getitem_返回值的数据类型可选择范围很多,一种可以选择的数据类型是:图像为numpy.array,标记为int数据类型。
这里写图片描述
示例:

from __future__ import print_function
import torch.utils.data as data
import torch

class MyDataset(data.Dataset):
    def __init__(self, images, labels):
        self.images = images
        self.labels = labels

    def __getitem__(self, index):#返回的是tensor
        img, target = self.images[index], self.labels[index]
        return img, target

    def __len__(self):
        return len(self.images)

dataset = MyDataset(images, labels)
 
 

    生成batch数据

    现在有了由数据文件生成的结构数据MyDataset,那么怎么在训练时提供batch数据呢?PyTorch提供了生成batch数据的类。

    PyTorch用类torch.utils.data.DataLoader加载数据,并对数据进行采样,生成batch迭代器。

    class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)

    参数
    dataset:Dataset类型,从其中加载数据
    batch_size:int,可选。每个batch加载多少样本
    shuffle:bool,可选。为True时表示每个epoch都对数据进行洗牌,对每一个sample进行洗牌,一个batch包含多个sample,一个sample中可能包含多张图片,一个sample就是一个iterm,可以包含一张图片也可以包含多个图片
    sampler:Sampler,可选。从数据集中采样样本的方法。
    num_workers:int,可选。加载数据时使用多少子进程。默认值为0,表示在主进程中加载数据。
    collate_fn:callable,可选。
    pin_memory:bool,可选
    drop_last:bool,可选。True表示如果最后剩下不完全的batch,丢弃。False表示不丢弃。

    示例

    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    train_loader = torch.utils.data.DataLoader(
        MyDataset(images, labels), batch_size=args.batch_size, shuffle=True, **kwargs)
     
     

      其他用法
      len(train_loader) :返回的是len(dataset)/batch_size

      ### 回答1: 在使用PyTorch进行模型训练时,数据不一定会完全整除batch size,即训练集中的样本数量不能被batch size整除得到一个整数结果。这种情况在实际应用中很常见,并且PyTorch提供了一些处理方法来处理这种情况。 第一种方法是将丢失的不足一个batch size的数据丢弃,这种方法简单直接,但会导致数据的浪费。这种做法适用于样本数量很大,略微丢失一部分数据不会对训练结果产生显著影响的情况。 第二种方法是通过在数据集中添加额外的样本,使得总样本数量能够整除batch size。这种方法可以使用一些数据增强技术,如图像翻转、旋转、缩放等,生成一些与原始样本类似但不完全相同的样本。这样可以保证所有样本都被用于训练,并且不会出现数据浪费的情况。 第三种方法是使用PyTorch的sampler,例如RandomSampler或SequentialSampler,来处理数据不整除batch size的情况。这些sampler可以控制数据加载的顺序和方式,确保每个batch的大小符合要求,即使总样本数量不能被batch size整除。 总之,对于数据不整除batch size的情况,我们可以通过丢弃部分数据、添加额外的样本或使用sampler等方法来处理。具体选择哪种方法取决于实际问题的特点和数据集的规模。 ### 回答2: 当pytorch训练时数据不整除batch size时,会出现最后一个batch大小小于设定的batch size的情况。在处理这个问题时,可以使用以下两种方法: 1. 丢弃余下的数据:一种简单的处理方式是丢弃余下的数据,确保所有的batch大小一致。如果数据集的大小不能被batch size整除,最后一个batch中剩余的数据会被丢弃。这种方法的好处是代码实现简单,但可能会浪费一些数据。 2. 动态调整batch大小:另一种处理方式是动态调整最后一个batch的大小,使其能够包含剩余的数据。例如,可以根据数据集的大小,将最后一个batch size设置为能够包含剩余数据的最小值,而其他batch size保持不变。这种方法需要一些额外的计算去确定最后一个batch的大小,但确保了所有的数据都能够被使用。 无论采用哪种方法,需要注意的是,在数据不整除batch size的情况下,最后一个batch的大小会发生变化,可能会对模型的训练结果产生一些影响。因此,在使用这些方法时,需要进行相关的实验和评估,确保模型的性能和效果仍然能够达到预期。 ### 回答3: 当使用PyTorch训练时,数据不整除批次大小是一个常见的情况。在这种情况下,可能会有一个或多个训练示例无法放入一个批次中,因为它们的数量不能被批次大小整除。 这种情况下,PyTorch通常有两种处理方式: 1. 去掉无法放入批次中的示例:在训练过程中,可以选择丢弃无法放入批次中的那些训练示例。这种情况下,相当于忽略了这些示例的训练,可能会导致训练数据的损失一定的准确性,但也能够保证批次训练的正常进行。 2. 动态调整批次大小:另一种处理方式是在训练过程中动态调整批次大小,以确保所有训练示例都能够得到使用。这意味着在每个批次中,最后一个没有填满的位置将留空或使用不足一个批次大小的示例数量。这种方法保证了所有示例都能够被用于训练,但可能会带来一些计算上的额外开销,因为每个批次的大小可能是不统一的。 总之,当训练时数据不整除批次大小时,可以选择去掉无法放入批次的示例或动态调整批次大小。具体使用哪种处理方法取决于情境和需求。
      评论
      添加红包

      请填写红包祝福语或标题

      红包个数最小为10个

      红包金额最低5元

      当前余额3.43前往充值 >
      需支付:10.00
      成就一亿技术人!
      领取后你会自动成为博主和红包主的粉丝 规则
      hope_wisdom
      发出的红包
      实付
      使用余额支付
      点击重新获取
      扫码支付
      钱包余额 0

      抵扣说明:

      1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
      2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

      余额充值