PyTorch~自定义数据读取

这次是PyTorch的自定义数据读取pipeline模板和相关trciks以及如何优化数据读取的pipeline等。

因为有torch也放人工智能模块了~

从PyTorch的数据对象类Dataset开始。Dataset在PyTorch中的模块位于utils.data下。

from torch.utils.data import Dataset

围绕Dataset对象分别从原始模板、torchvision的transforms模块、使用pandas来辅助读取、torch内置数据划分功能和DataLoader来展开阐述。

Dataset原始模板

PyTorch官方为我们提供了自定义数据读取的标准化代码代码模块,作为一个读取框架,我们这里称之为原始模板。其代码结构如下:

from torch.utils.data import Dataset
class CustomDataset(Dataset):
    def __init__(self, ...):
        # stuff
        
    def __getitem__(self, index):
        # stuff
        return (img, label)
        
    def __len__(self):
        # return examples size
        return count

根据这个标准化的代码模板,我们只需要根据自己的数据读取任务,分别往__init__()、__getitem__()和__len__()三个方法里添加读取逻辑即可。作为PyTorch范式下的数据读取以及为了后续的data loader,三个方法缺一不可。其中:

  • __init__()函数用于初始化数据读取逻辑,比如读取包含标签和图片地址的csv文件、定义transform组合等。

  • __getitem__()函数用来返回数据和标签。目的上是为了能够被后续的dataloader所调用。

  • __len__()函数则用于返回样本数量。

现在我们往这个框架里填几行代码来形成一个简单的数字案例。创建一个从1到100的数字例子:

from torch.utils.data import Dataset
class CustomDataset(Dataset):
    def __init__(self):
        self.samples = list(range(1, 101))
    def __len__(self):
        return len(self.samples)
    def __getitem__(self, idx):
        return self.samples[idx]
        
if __name__ == '__main__':
    dataset = CustomDataset()
    print(len(dataset))
    print(dataset[50])
    print(dataset[1:100])

 添加torchvision.transforms

然后我们来看如何从内存中读取数据以及如何在读取过程中嵌入torchvision中的transforms功能。torchvision是一个独立于torch的关于数据、模型和一些图像增强操作的辅助库。主要包括datasets默认数据集模块、models经典模型模块、transforms图像增强模块以及utils模块等。在使用torch读取数据的时候,一般会搭配上transforms模块对数据进行一些处理和增强工作。

添加了tranforms之后的读取模块可以改写为:

from torch.utils.data import Dataset
from torchvision import transforms as T

class CustomDataset(Dataset):
    def __init__(self, ...):
        # stuff
        ...
        # compose the transforms methods
        self.transform = T.Compose([T.CenterCrop(100),
                                T.ToTensor()])
        
    def __getitem__(self, index):
        # stuff
        ...
        data = # Some data read from a file or image
        # execute the transform
        data = self.transform(data)  
        return (img, label)
        
    def __len__(self):
        # return examples size
        return count
        
if __name__ == '__main__':
    # Call the dataset
    custom_dataset = CustomDataset(...)

可以看到,我们使用了Compose方法来把各种数据处理方法聚合到一起进行定义数据转换方法。通常作为初始化方法放在__init__()函数下。我们以猫狗图像数据为例进行说明。

定义数据读取方法如下:

class DogCat(Dataset):    
    def __init__(self, root, transforms=None, train=True, val=False):
        """
        get images and execute transforms.
        """
        self.val = val
        imgs = [os.path.join(root, img) for img in os.listdir(root)]
        # train: Cats_Dogs/trainset/cat.1.jpg
        # val: Cats_Dogs/valset/cat.10004.jpg
        imgs = sorted(imgs, key=lambda x: x.split('.')[-2])
        self.imgs = imgs         
        if transforms is None:
            # normalize      
            normalize = T.Normalize(mean = [0.485, 0.456, 0.406],
                                     std = [0.229, 0.224, 0.225])
            # trainset and valset have different data transform 
            # trainset need data augmentation but valset don't.
            # valset

            if self.val:
                self.transforms = T.Compose([
                    T.Resize(224),
                    T.CenterCrop(224),
                    T.ToTensor(),
                    normalize
                ])
            # trainset
            else:
                self.transforms = T.Compose([
                    T.Resize(256),
                    T.RandomResizedCrop(224),
                    T.RandomHorizontalFlip(),
                    T.ToTensor(),
                    normalize
                ])
                       
    def __getitem__(self, index):
        """
        return data and label
        """
        img_path = self.imgs[index]
        label = 1 if 'dog' in img_path.split('/')[-1] else 0
        data = Image.open(img_path)
        data = self.transforms(data)
        return data, label
  
    def __len__(self):
        """
        return images size.
        """
        return len(self.imgs)

if __name__ == "__main__":
    train_dataset = DogCat('./Cats_Dogs/trainset/', train=True)
    print(len(train_dataset))
    print(train_dataset[0])

因为这个数据集已经分好了训练集和验证集,所以在读取和transforms的时候需要进行区分。运行示例如下:

与pandas一起使用

很多时候数据的目录地址和标签都是通过csv文件给出的。如下所示:

此时在数据读取的pipeline中我们需要在__init__()方法中利用pandas把csv文件中包含的图片地址和标签融合进去。相应的数据读取pipeline模板可以改写为:

class CustomDatasetFromCSV(Dataset):
    def __init__(self, csv_path):
        """
        Args:
            csv_path (string): path to csv file
            transform: pytorch transforms for transforms and tensor conversion
        """
        # Transforms
        self.to_tensor = transforms.ToTensor()
        # Read the csv file
        self.data_info = pd.read_csv(csv_path, header=None)
        # First column contains the image paths
        self.image_arr = np.asarray(self.data_info.iloc[:, 0])
        # Second column is the labels
        self.label_arr = np.asarray(self.data_info.iloc[:, 1])
        # Calculate len
        self.data_len = len(self.data_info.index)

    def __getitem__(self, index):
        # Get image name from the pandas df
        single_image_name = self.image_arr[index]
        # Open image
        img_as_img = Image.open(single_image_name)
        # Transform image to tensor
        img_as_tensor = self.to_tensor(img_as_img)
        # Get label of the image based on the cropped pandas column
        single_image_label = self.label_arr[index]
        return (img_as_tensor, single_image_label)

    def __len__(self):
        return self.data_len

if __name__ == "__main__":
    # Call dataset
    dataset =  CustomDatasetFromCSV('./labels.csv')

以mnist_label.csv文件为示例:

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms as T
from PIL import Image
import os
import numpy as np
import pandas as pd

class CustomDatasetFromCSV(Dataset):
    def __init__(self, csv_path):
        """
        Args:
            csv_path (string): path to csv file            
            transform: pytorch transforms for transforms and tensor conversion
        """
        # Transforms
        self.to_tensor = T.ToTensor()
        # Read the csv file
        self.data_info = pd.read_csv(csv_path, header=None)
        # First column contains the image paths
        self.image_arr = np.asarray(self.data_info.iloc[:, 0])
        # Second column is the labels
        self.label_arr = np.asarray(self.data_info.iloc[:, 1])
        # Third column is for an operation indicator
        self.operation_arr = np.asarray(self.data_info.iloc[:, 2])
        # Calculate len
        self.data_len = len(self.data_info.index)

    def __getitem__(self, index):
        # Get image name from the pandas df
        single_image_name = self.image_arr[index]
        # Open image
        img_as_img = Image.open(single_image_name)
        # Check if there is an operation
        some_operation = self.operation_arr[index]
        # If there is an operation
        if some_operation:
            # Do some operation on image
            # ...
            # ...
            pass

        # Transform image to tensor
        img_as_tensor = self.to_tensor(img_as_img)
        # Get label of the image based on the cropped pandas column
        single_image_label = self.label_arr[index]
        return (img_as_tensor, single_image_label)

    def __len__(self):
        return self.data_len

if __name__ == "__main__":
    transform = T.Compose([T.ToTensor()])
    dataset = CustomDatasetFromCSV('./mnist_labels.csv')
    print(len(dataset))
    print(dataset[5])

运行示例如下:

训练集验证集划分

一般来说,为了模型训练的稳定,我们需要对数据划分训练集和验证集。torch的Dataset对象也提供了random_split函数作为数据划分工具,且划分结果可直接供后续的DataLoader使用。

以kaggle的花朵数据为例:  whaosoft aiot http://143ai.com

from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms as T
from torch.utils.data import random_split

transform = T.Compose([
    T.Resize((224, 224)),
    T.RandomHorizontalFlip(),
    T.ToTensor()
 ])

dataset = ImageFolder('./flowers_photos', transform=transform)
print(dataset.class_to_idx)

trainset, valset = random_split(dataset, 
                [int(len(dataset)*0.7), len(dataset)-int(len(dataset)*0.7)])

trainloader = DataLoader(dataset=trainset, batch_size=32, shuffle=True, num_workers=1)
for i, (img, label) in enumerate(trainloader):
    img, label = img.numpy(), label.numpy()
    print(img, label)

valloader = DataLoader(dataset=valset, batch_size=32, shuffle=True, num_workers=1)
for i, (img, label) in enumerate(trainloader):
    img, label = img.numpy(), label.numpy()
    print(img.shape, label)

这里使用了ImageFolder模块,可以直接读取各标签对应的文件夹,部分运行示例如下: 

使用DataLoader

dataset方法写好之后,我们还需要使用DataLoader将其逐个喂给模型。上一节的数据划分我们已经用到了DataLoader函数。从本质上来讲,DataLoader只是调用了__getitem__()方法并按批次返回数据和标签。使用方法如下:

from torch.utils.data import DataLoader
from torchvision import transforms as T

if __name__ == "__main__":
    # Define transforms
    transformations = T.Compose([T.ToTensor()])
    # Define custom dataset
    dataset = CustomDatasetFromCSV('./labels.csv')
    # Define data loader
    data_loader = DataLoader(dataset=dataset, batch_size=10, shuffle=True)
    for images, labels in data_loader:
        # Feed the data to the model

以上就是PyTorch读取数据的Pipeline主要方法和流程。基于Dataset对象的基本框架不变,具体细节可自定义化调整。

  • 6
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
### 回答1: 在 PyTorch读取自定义数据集的一般步骤如下: 1. 定义数据集类:首先需要定义一个数据集类,继承自 `torch.utils.data.Dataset` 类,并实现 `__getitem__` 和 `__len__` 方法。在 `__getitem__` 方法中,根据索引返回一个样本的数据和标签。 2. 加载数据集:使用 `torch.utils.data.DataLoader` 类加载数据集,可以设置批量大小、多线程读取数据等参数。 下面是一个简单的示例代码,演示如何使用 PyTorch 读取自定义数据集: ```python import torch from torch.utils.data import Dataset, DataLoader class CustomDataset(Dataset): def __init__(self, data, targets): self.data = data self.targets = targets def __getitem__(self, index): x = self.data[index] y = self.targets[index] return x, y def __len__(self): return len(self.data) # 加载训练集和测试集 train_data = ... train_targets = ... train_dataset = CustomDataset(train_data, train_targets) train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) test_data = ... test_targets = ... test_dataset = CustomDataset(test_data, test_targets) test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False) # 训练模型 for epoch in range(num_epochs): for batch_idx, (data, targets) in enumerate(train_loader): # 前向传播、反向传播,更新参数 ... ``` 在上面的示例代码中,我们定义了一个 `CustomDataset` 类,加载了训练集和测试集,并使用 `DataLoader` 类分别对它们进行批量读取。在训练模型时,我们可以像使用 PyTorch 自带的数据集一样,循环遍历每个批次的数据和标签,进行前向传播、反向传播等操作。 ### 回答2: PyTorch是一个开源的深度学习框架,它提供了丰富的功能用于读取和处理自定义数据集。下面是一个简单的步骤来读取自定义数据集。 首先,我们需要定义一个自定义数据集类,该类应继承自`torch.utils.data.Dataset`类,并实现`__len__`和`__getitem__`方法。`__len__`方法应返回数据集的样本数量,`__getitem__`方法根据给定索引返回一个样本。 ```python import torch from torch.utils.data import Dataset class CustomDataset(Dataset): def __init__(self, data): self.data = data def __len__(self): return len(self.data) def __getitem__(self, idx): sample = self.data[idx] return torch.tensor(sample) ``` 接下来,我们可以创建一个数据集实例并传入自定义数据。假设我们有一个包含多个样本的列表 `data`。 ```python data = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] dataset = CustomDataset(data) ``` 然后,我们可以使用`torch.utils.data.DataLoader`类加载数据集,并指定批次大小、是否打乱数据等。 ```python batch_size = 2 dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True) ``` 现在,我们可以迭代数据加载器来获取批次的样本。 ```python for batch in dataloader: print(batch) ``` 上面的代码将打印出两个批次的样本。如果`shuffle`参数设置为`True`,则每个批次的样本将是随机的。 总而言之,PyTorch提供了简单而强大的工具来读取和处理自定义数据集,可以根据实际情况进行适当修改和扩展。 ### 回答3: PyTorch是一个流行的深度学习框架,可以用来训练神经网络模型。要使用PyTorch读取自定义数据集,可以按照以下几个步骤进行: 1. 准备数据集:将自定义数据集组织成合适的目录结构。通常情况下,可以将数据集分为训练集、验证集和测试集,每个集合分别放在不同的文件夹中。确保每个文件夹中的数据按照类别进行分类,以便后续的标签处理。 2. 创建数据加载器:在PyTorch中,数据加载器是一个有助于有效读取和处理数据的类。可以使用`torchvision.datasets.ImageFolder`类创建一个数据加载器对象,通过传入数据集的目录路径来实现。 3. 数据预处理:在将数据传入模型之前,可能需要对数据进行一些预处理操作,例如图像变换、标准化或归一化等。可以使用`torchvision.transforms`中的类来实现这些预处理操作,然后将它们传入数据加载器中。 4. 创建数据迭代器:数据迭代器是连接数据集和模型的重要接口,它提供了一个逐批次加载数据的功能。可以使用`torch.utils.data.DataLoader`类创建数据迭代器对象,并设置一些参数,例如批量大小、是否打乱数据等。 5. 使用数据迭代器:在训练时,可以使用Python的迭代器来遍历数据集并加载数据。通常,它会在每个迭代步骤中返回一个批次的数据和标签。可以通过`for`循环来遍历数据迭代器,并在每个步骤中处理批次数据和标签。 这样,我们就可以在PyTorch中成功读取并处理自定义数据集。通过这种方式,我们可以更好地利用PyTorch的功能来训练和评估自己的深度学习模型。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值