Pytorch实现预先处理数据读取

图像分类数据集—关于如何预先处理数据读取问题
  • MNIST数据集是图像分类中广泛使用的数据集之一,但是作为基准数据集过于简单。我们采用更加复杂的Fashion-MNIST数据集
    !pip install d2l
    %matplotlib inline
    import torch
    import torchvision
    from torch.utils import data
    from torchvision import transforms # 数据转换的库
    from d2l import torch as d2l
    d2l.use_svg_display() # 提高清晰度 
    
  • 通过框架中的内置函数将Fashion-MNIST数据集下载并读取到内存中
    # 通过ToTensor实例将图像数据从PIL类型变换为32位浮点数格式
    # 并除以255使得所有像素的数值均在0到1之间
    trans = transforms.ToTensor() # 定义转换器
    mnist_strain = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)
    # 将该数据集下载到上级目录点data下面../data,train=True是一个训练集,transform=trans得到的是转换后的张量而非图片
    mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)
    len(mnist_strain), len(mnist_test)
    mnist_train[0][0].shape # [0][0]第一张图片的张量[0][1]第一张图片的标签
    # In general, mnist_strain[i][0] refers to the i-th image in the training dataset, 
    # and mnist_strain[i][1] refers to the label of the i-th image.
    # 我们可以通过方括号[]来访问任意一个样本,下面获取第一个样本的图像和标签。
    # feature, label = mnist_train[0] 
    
    • 补充知识点
    • 在使用 PyTorch 进行机器学习时,通常需要将输入数据转换为张量的形式,因为PyTorch中的大多数模型和函数都是针对张量进行操作的。transforms.ToTensor() 就是用于将输入数据转换为张量的一个方便的工具。使用方法如下:
    import torch
    from torchvision import transforms
    # 定义转换器
    to_tensor = transforms.ToTensor()
    # 将输入数据转换为张量
    inputs = to_tensor(inputs)
    
    • 在上面的代码中,inputs 是输入数据,它可以是一个 NumPy 数组、PIL 图像或者其他形式的数据。通过调用 to_tensor 对象的 call 方法,就可以将 inputs 转换为张量的形式。
  • 两个可视化数据集的函数
    # 以下函数可以将数值标签转成相应的文本标签。
    def get_fashion_mnist_labels(labels):
        text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                      'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
        return [text_labels[int(i)] for i in labels]
        # return的逻辑就是i是逐个被赋予labels的值,而后通过i作为索引查找对应的标签
        # 最终返回一个新的列表,即位labels所对应的文本信息
    """
    这个函数定义了一个有十个元素的列表 text_labels,这十个元素分别表示十种服装的名称。然后,它使用了一个列表推导式来创建并返回一个新列表。这个列表推导式包含一个循环,循环变量名为 i,循环体为 text_labels[int(i)]。这个列表推导式会对 labels 中的每个元素执行一次循环,并将每次循环的结果添加到新列表中。
    """
    # 下面定义一个可以在一行里画出多张图像和对应标签的函数。
    def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):  
      """Plot a list of images."""
      figsize = (num_cols * scale, num_rows * scale) # 适当的放大比例,方便后续观看
      _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
      # subplots返回两个参数,第一个是图片总数,第二个是一个二维数组,要访问其需要通过二维数组的形式
      axes = axes.flatten()
      # 扁平化处理将二维子图像数组转化为一维度,方便后续遍历
      for i, (ax, img) in enumerate(zip(axes, imgs)):
        # zip函数用于将axes和imgs按照顺序每一对绑定后返回一个列表[(1, 'a'), (2, 'b'), (3, 'c')]
        # 这样我们可以用索引访问
        if torch.is_tensor(img):
          ax.imshow(img.numpy()) # 是张量时候的画图方式
        else:
          ax.imshow(img) # 是PIL的画图格式
        # 这一句if其实是判断是否为ax是否为tensor张量,ax.的意思就是在子图中表示img这个图片
        ax.axes.get_xaxis().set_visible(False) # 不要X和Y轴
        ax.axes.get_yaxis().set_visible(False)
        if titles:
          ax.set_title(titles[i]) # 如果有titles则按照顺序表示在对应位置
      return axes
    
  • 小批度读取数据并同时进行读取速度测试
    batch_size = 256
    def get_dataloader_workers():  
        """使用4个进程来读取数据。"""
        return 2
    train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,
                                num_workers=get_dataloader_workers())
    # 几个进程取数据num_workers
    timer = d2l.Timer()
    for X, y in train_iter:
        continue
    f'{timer.stop():.2f} sec'
    # 测试数据读取的速度,避免出现训练快数据读不上
  • 将数据读取部分进行封装成为一个函数
      def load_data_fashion_mnist(batch_size, resize=None):  
          """下载Fashion-MNIST数据集,然后将其加载到内存中。"""
          trans = [transforms.ToTensor()]
          if resize:
              trans.insert(0, transforms.Resize(resize))
          trans = transforms.Compose(trans)
          mnist_train = torchvision.datasets.FashionMNIST(root="../data",
                                                          train=True,
                                                          transform=trans,
                                                          download=True)
          mnist_test = torchvision.datasets.FashionMNIST(root="../data",
                                                        train=False,
                                                        transform=trans,
                                                        download=True)
          return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                                  num_workers=get_dataloader_workers()),
                  data.DataLoader(mnist_test, batch_size, shuffle=False,
                                  num_workers=get_dataloader_workers()))
      # 封装成一个读取函数
      train_iter, test_iter = load_data_fashion_mnist(32, resize=64)
      for X, y in train_iter:
          print(X.shape, X.dtype, y.shape, y.dtype)
          break
    
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值