读取blob_pytorch中继承Dataset自定义数据读取流

7c7d3cc2ad99f3e62e399c3c85f80431.png

pytorch虽然简单易用,但是其高度的封装使得初学者难以理解数据是如何读入的。对于自己的任务,很可能pytorch提供的数据读取机制难以完全满足任务要求,所以我们需要学习如何使用pytorch提供的torch.utils.data.Dataset来自定义数据读取流程(文末附完整代码)。下面来分析一下Dataset类的源码【1】:

class Dataset(object):
    """此处省略"""
    def __getitem__(self, index):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])

也就是说我们只需要实现“__getitem__(self, index)”方法即可,这个__getitem__()方法的index参数看起来有些让人困惑,查阅python官方文档【2】发现这个方法是object的方法:

8bdb530b6b8bd490223905d7f04632c6.png

所以这个index的值就是索引值,比如我们想要数据集中第二张图片索引就是2,关于__getitem__()的具体的介绍可以看【3】。

一、自定义SingeClassDataset

我在自定义Dataset时共分为3个步骤,目的是以后能够好的进行功能扩展:

  • 初始化图像路径模块__init__()
  • 图像转Tensor模块_read_convert_image()
  • 数据索引模块__getitem__()

由于目的是继承Dataset类,所以应该采用__init__()来存储数据路径,在存储之前要先检查输入的路径是不是正确的路径,防止图片读取失败:

def __init__(self, file_path):
    # 保证输入的是正确的路径
    if not os.path.isdir(file_path):
        raise ValueError("input file_path is not a dir")
    self.file_path = file_path
    # 获取路径下所有的图片名称,必须保证路径内没有图片以外的数据
    self.image_list = os.listdir(file_path)
    # 将PIL的Image转为Tensor
    self.transforms = T.ToTensor()

读取图像采用python内置的PIL提供的Image类型,这也是pytorch支持的核心类型。读取Image类型的图片后可以直接通过torchvision提供的变换,转为pytorch需要的Tensor类型:

def _read_convert_image(self, image_name):
    image = Image.open(image_name)
    image = self.transforms(image).float()
    return image

拥有上述两个方法以后,就可以实现完整的数据读取了。根据图像的存储方式不同可以采用多种读取策略,常见的情况有两种:图像在一个文件夹中、图像在多个文件夹中。下面实现的__getitem__()方法针对于所有的图像在一个文件夹内的情况:

def __getitem__(self, index):
    # 根据index获取图片完整路径
    image_path = os.path.join(self.file_path, self.image_list[index])
    # 都图片并转为Tensor
    image = self._read_convert_image(image_path)
    return image

二、测试自定义的SingleClassDataset

测试之前首先准备好数据放到“data”文件夹中,如下:

60653dfb1fe66eb22864166b7696d757.png

定义了__getitem__()方法的类就可以通过“索引”获得数据,下面来看一下数据是正确的读入了,可视化采用的是matplotlib,下面的代码展示了如何可视化前16张图片:

import matplotlib.pyplot as plt
MyDataset = SingleClassDataset(file_path="data/")
plt.figure()
for i in range(16):
    plt.subplot(4, 4, i+1)
    plt.imshow(MyDataset[i].numpy().transpose(1, 2, 0))
plt.show()

得到了下面的结果,就说明数据读取是没问题的:

623e9f8350f39bdec4e4c09bf03f952f.png

三、完整代码

由于之前pytorch的版本还必须实现"__len__()"方法用于返回数据集的长度,所以下面的代码实现了它,但是当前版本的pytorch已经不再强制实现这个函数了,整体代码如下:

from torch.utils.data import Dataset
import os
from PIL import Image
import torchvision.transforms as T


class SingleClassDataset(Dataset):
    """
    This Dataset only work for a folder that contains one class image!!!
    """

    def __init__(self, file_path):
        # 保证输入的是正确的路径
        if not os.path.isdir(file_path):
            raise ValueError("input file_path is not a dir")
        self.file_path = file_path
        # 获取路径下所有的图片名称,必须保证路径内没有图片以外的数据
        self.image_list = os.listdir(file_path)
        # 将PIL的Image转为Tensor
        self.transforms = T.ToTensor()

    def __getitem__(self, index):
        # 根据index获取图片完整路径
        image_path = os.path.join(self.file_path, self.image_list[index])
        # 都图片并转为Tensor
        image = self._read_convert_image(image_path)
        return image

    def _read_convert_image(self, image_name):
        image = Image.open(image_name)
        image = self.transforms(image).float()
        return image

    def __len__(self):
        return len(self.image_list)


import matplotlib.pyplot as plt
MyDataset = SingleClassDataset(file_path="data/")
plt.figure()
for i in range(16):
    plt.subplot(4, 4, i+1)
    plt.imshow(MyDataset[i].numpy().transpose(1, 2, 0))

plt.show()

参考:

【1】https://github.com/pytorch/pytorch/blob/master/torch/utils/data/dataset.py

【2】https://docs.python.org/3/reference/datamodel.html#object.__getitem__

【3】https://zhuanlan.zhihu.com/p/87786297

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
Go语言(也称为Golang)是由Google开发的一种静态强类型、编译型的编程语言。它旨在成为一门简单、高效、安全和并发的编程语言,特别适用于构建高性能的服务器和分布式系统。以下是Go语言的一些主要特点和优势: 简洁性:Go语言的语法简单直观,易于学习和使用。它避免了复杂的语法特性,如继承、重载等,转而采用组合和接口来实现代码的复用和扩展。 高性能:Go语言具有出色的性能,可以媲美C和C++。它使用静态类型系统和编译型语言的优势,能够生成高效的机器码。 并发性:Go语言内置了对并发的支持,通过轻量级的goroutine和channel机制,可以轻松实现并发编程。这使得Go语言在构建高性能的服务器和分布式系统时具有天然的优势。 安全性:Go语言具有强大的类型系统和内存管理机制,能够减少运行时错误和内存泄漏等问题。它还支持编译时检查,可以在编译阶段就发现潜在的问题。 标准库:Go语言的标准库非常丰富,包含了大量的实用功能和工具,如网络编程、文件操作、加密解密等。这使得开发者可以更加专注于业务逻辑的实现,而无需花费太多时间在底层功能的实现上。 跨平台:Go语言支持多种操作系统和平台,包括Windows、Linux、macOS等。它使用统一的构建系统(如Go Modules),可以轻松地跨平台编译和运行代码。 开源和社区支持:Go语言是开源的,具有庞大的社区支持和丰富的资源。开发者可以通过社区获取帮助、分享经验和学习资料。 总之,Go语言是一种简单、高效、安全、并发的编程语言,特别适用于构建高性能的服务器和分布式系统。如果你正在寻找一种易于学习和使用的编程语言,并且需要处理大量的并发请求和数据,那么Go语言可能是一个不错的选择。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值