PyTorch 中自定义数据集的读取方法小结

PyTorch 中自定义数据集的读取方法小结

作者: PyTorch 中文网 发布: 2018年9月20日 5,922阅读 0评论

虽然说网上关于 PyTorch 数据集读取的文章和教程多的很,但总觉得哪里不对,尤其是对新手来说,可能需要很长一段时间来钻研和尝试。所以这里我们 PyTorch 中文网为大家总结常用的几种自定义数据集(Custom Dataset)的读取方式(采用 Dataloader)。

本文将涉及以下几个方面:

  • 自定义数据集基础方法
  • 使用 Torchvision Transforms
  • 换一种方法使用 Torchvision Transforms
  • 结合 Pandas 读取 csv 文件
  • 结合 Pandas 使用 __getitem__()
  • 使用 Dataloader 读取自定义数据集

文章目录 [隐藏]

自定义数据集基础方法

首先要创建一个 Dataset 类:

from torch.utils.data.dataset import Dataset class MyCustomDataset(Dataset): def __init__(self, ...): # stuff def __getitem__(self, index): # stuff return (img, label) def __len__(self): return count

1

2

3

4

5

6

7

8

9

10

11

12

from torch.utils.data.dataset import Dataset

 

class MyCustomDataset(Dataset):

    def __init__(self, ...):

        # stuff

        

    def __getitem__(self, index):

        # stuff

        return (img, label)

 

    def __len__(self):

        return count

这个代码中:

  • __init__() 一些初始化过程写在这里
  • __len__() 返回所有数据的数量
  • __getitem__() 返回数据和标签,可以这样显示调用:

img, label = MyCustomDataset.__getitem__(99)

1

img, label = MyCustomDataset.__getitem__(99)

使用 Torchvision Transforms

Transform 最常见的使用方法是:

from torch.utils.data.dataset import Dataset from torchvision import transforms class MyCustomDataset(Dataset): def __init__(self, ..., transforms=None): # stuff ... self.transforms = transforms def __getitem__(self, index): # stuff ... data = # 一些读取的数据 if self.transforms is not None: data = self.transforms(data) # 如果 transform 不为 None,则进行 transform 操作 return (img, label) def __len__(self): return count if __name__ == \'__main__\': # 定义我们的 transforms (1) transformations = transforms.Compose([transforms.CenterCrop(100), transforms.ToTensor()]) # 创建 dataset custom_dataset = MyCustomDataset(..., transformations)

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

from torch.utils.data.dataset import Dataset

from torchvision import transforms

 

class MyCustomDataset(Dataset):

    def __init__(self, ..., transforms=None):

        # stuff

        ...

        self.transforms = transforms

        

    def __getitem__(self, index):

        # stuff

        ...

        data = # 一些读取的数据

        if self.transforms is not None:

            data = self.transforms(data)

        # 如果 transform 不为 None,则进行 transform 操作

        return (img, label)

 

    def __len__(self):

        return count

        

if __name__ == \'__main__\':

    # 定义我们的 transforms (1)

    transformations = transforms.Compose([transforms.CenterCrop(100), transforms.ToTensor()])

    # 创建 dataset

    custom_dataset = MyCustomDataset(..., transformations)

 

换一种方法使用 Torchvision Transforms

有些人不喜欢把 transform 操作写在 Dataset 外面(上面代码里的注释 1),所以还有一种写法:

from torch.utils.data.dataset import Dataset from torchvision import transforms class MyCustomDataset(Dataset): def __init__(self, ...): # stuff ... # (2) 一种方法是单独定义 transform self.center_crop = transforms.CenterCrop(100) self.to_tensor = transforms.ToTensor() # (3) 或者写成下面这样 self.transformations = \ transforms.Compose([transforms.CenterCrop(100), transforms.ToTensor()]) def __getitem__(self, index): # stuff ... data = #一些读取的数据 # 当第二次调用 transform 时,调用的是 __call__() data = self.center_crop(data) # (2) data = self.to_tensor(data) # (2) # 或者写成下面这样 data = self.trasnformations(data) # (3) # 注意 (2) 和 (3) 中只需要实现一种 return (img, label) def __len__(self): return count if __name__ == \'__main__\': custom_dataset = MyCustomDataset(...)

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

from torch.utils.data.dataset import Dataset

from torchvision import transforms

 

class MyCustomDataset(Dataset):

    def __init__(self, ...):

        # stuff

        ...

        # (2) 一种方法是单独定义 transform

        self.center_crop = transforms.CenterCrop(100)

        self.to_tensor = transforms.ToTensor()

        

        # (3) 或者写成下面这样

        self.transformations = \

            transforms.Compose([transforms.CenterCrop(100),

                                transforms.ToTensor()])

        

    def __getitem__(self, index):

        # stuff

        ...

        data = #一些读取的数据

        

        # 当第二次调用 transform 时,调用的是 __call__()

        data = self.center_crop(data)  # (2)

        data = self.to_tensor(data)  # (2)

        

        # 或者写成下面这样

        data = self.trasnformations(data)  # (3)

        

        # 注意 (2) 和 (3) 中只需要实现一种

        return (img, label)

 

    def __len__(self):

        return count

        

if __name__ == \'__main__\':

    custom_dataset = MyCustomDataset(...)

 

结合 Pandas 读取 csv 文件

假如说我们想从一个 csv 文件中用 Pandas 读取数据。一个 csv 示例如下:

File NameLabelExtra Operation
tr_0.png5TRUE
tr_1.png0FALSE
tr_1.png4FALSE

如果我们需要在自定义数据集里从这个 csv 文件读取文件名,可以这样做:

class CustomDatasetFromImages(Dataset): def __init__(self, csv_path): """ Args: csv_path (string): csv 文件路径 img_path (string): 图像文件所在路径 transform: transform 操作 """ # Transforms self.to_tensor = transforms.ToTensor() # 读取 csv 文件 self.data_info = pd.read_csv(csv_path, header=None) # 文件第一列包含图像文件的名称 self.image_arr = np.asarray(self.data_info.iloc[:, 0]) # 第二列是图像的 label self.label_arr = np.asarray(self.data_info.iloc[:, 1]) # 第三列是决定是否进行额外操作 self.operation_arr = np.asarray(self.data_info.iloc[:, 2]) # 计算 length self.data_len = len(self.data_info.index) def __getitem__(self, index): # 从 pandas df 中得到文件名 single_image_name = self.image_arr[index] # 读取图像文件 img_as_img = Image.open(single_image_name) # 检查需不需要额外操作 some_operation = self.operation_arr[index] # 如果需要额外操作 if some_operation: # ... # ... pass # 把图像转换成 tensor img_as_tensor = self.to_tensor(img_as_img) # 得到图像的 label single_image_label = self.label_arr[index] return (img_as_tensor, single_image_label) def __len__(self): return self.data_len if __name__ == "__main__": custom_mnist_from_images = \ CustomDatasetFromImages(\'../data/mnist_labels.csv\')

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

42

43

44

45

46

47

48

class CustomDatasetFromImages(Dataset):

    def __init__(self, csv_path):

        """

        Args:

            csv_path (string): csv 文件路径

            img_path (string): 图像文件所在路径

            transform: transform 操作

        """

        # Transforms

        self.to_tensor = transforms.ToTensor()

        # 读取 csv 文件

        self.data_info = pd.read_csv(csv_path, header=None)

        # 文件第一列包含图像文件的名称

        self.image_arr = np.asarray(self.data_info.iloc[:, 0])

        # 第二列是图像的 label

        self.label_arr = np.asarray(self.data_info.iloc[:, 1])

        # 第三列是决定是否进行额外操作

        self.operation_arr = np.asarray(self.data_info.iloc[:, 2])

        # 计算 length

        self.data_len = len(self.data_info.index)

 

    def __getitem__(self, index):

        # 从 pandas df 中得到文件名

        single_image_name = self.image_arr[index]

        # 读取图像文件

        img_as_img = Image.open(single_image_name)

 

        # 检查需不需要额外操作

        some_operation = self.operation_arr[index]

        # 如果需要额外操作

        if some_operation:

            # ...

            # ...

            pass

        # 把图像转换成 tensor

        img_as_tensor = self.to_tensor(img_as_img)

 

        # 得到图像的 label

        single_image_label = self.label_arr[index]

 

        return (img_as_tensor, single_image_label)

 

    def __len__(self):

        return self.data_len

 

if __name__ == "__main__":

    custom_mnist_from_images =  \

        CustomDatasetFromImages(\'../data/mnist_labels.csv\')

 

结合 Pandas 使用 __getitem__()

另一种情况是 csv 文件中保存了我们需要的图像文件的像素值(比如有些 MNIST 教程就是这样的)。我们需要改动一下 __getitem__() 函数。

Labelpixel_1pixel_2
15099
021223
944112

代码如下:

class CustomDatasetFromCSV(Dataset): def __init__(self, csv_path, height, width, transforms=None): """ Args: csv_path (string): csv 文件路径 height (int): 图像高度 width (int): 图像宽度 transform: transform 操作 """ self.data = pd.read_csv(csv_path) self.labels = np.asarray(self.data.iloc[:, 0]) self.height = height self.width = width self.transforms = transform def __getitem__(self, index): single_image_label = self.labels[index] # 读取所有像素值,并将 1D array ([784]) reshape 成为 2D array ([28,28]) img_as_np = np.asarray(self.data.iloc[index][1:]).reshape(28,28).astype(\'uint8\') # 把 numpy array 格式的图像转换成灰度 PIL image img_as_img = Image.fromarray(img_as_np) img_as_img = img_as_img.convert(\'L\') # 将图像转换成 tensor if self.transforms is not None: img_as_tensor = self.transforms(img_as_img) # 返回图像及其 label return (img_as_tensor, single_image_label) def __len__(self): return len(self.data.index) if __name__ == "__main__": transformations = transforms.Compose([transforms.ToTensor()]) custom_mnist_from_csv = \ CustomDatasetFromCSV(\'../data/mnist_in_csv.csv\', 28, 28, transformations)

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

class CustomDatasetFromCSV(Dataset):

    def __init__(self, csv_path, height, width, transforms=None):

        """

        Args:

            csv_path (string): csv 文件路径

            height (int): 图像高度

            width (int): 图像宽度

            transform: transform 操作

        """

        self.data = pd.read_csv(csv_path)

        self.labels = np.asarray(self.data.iloc[:, 0])

        self.height = height

        self.width = width

        self.transforms = transform

 

    def __getitem__(self, index):

        single_image_label = self.labels[index]

        # 读取所有像素值,并将 1D array ([784]) reshape 成为 2D array ([28,28])

        img_as_np = np.asarray(self.data.iloc[index][1:]).reshape(28,28).astype(\'uint8\')

# 把 numpy array 格式的图像转换成灰度 PIL image

        img_as_img = Image.fromarray(img_as_np)

        img_as_img = img_as_img.convert(\'L\')

        # 将图像转换成 tensor

        if self.transforms is not None:

            img_as_tensor = self.transforms(img_as_img)

        # 返回图像及其 label

        return (img_as_tensor, single_image_label)

 

    def __len__(self):

        return len(self.data.index)

        

 

if __name__ == "__main__":

    transformations = transforms.Compose([transforms.ToTensor()])

    custom_mnist_from_csv = \

        CustomDatasetFromCSV(\'../data/mnist_in_csv.csv\', 28, 28, transformations)

使用 Dataloader 读取自定义数据集

PyTorch 中的 Dataloader 只是调用 __getitem__() 方法并组合成 batch,我们可以这样调用:

... if __name__ == "__main__": # 定义 transforms transformations = transforms.Compose([transforms.ToTensor()]) # 自定义数据集 custom_mnist_from_csv = \ CustomDatasetFromCSV(\'../data/mnist_in_csv.csv\', 28, 28, transformations) # 定义 data loader mn_dataset_loader = torch.utils.data.DataLoader(dataset=custom_mnist_from_csv, batch_size=10, shuffle=False) for images, labels in mn_dataset_loader: # 将数据传给网络模型

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

...

if __name__ == "__main__":

    # 定义 transforms

    transformations = transforms.Compose([transforms.ToTensor()])

    # 自定义数据集

    custom_mnist_from_csv = \

        CustomDatasetFromCSV(\'../data/mnist_in_csv.csv\',

                             28, 28,

                             transformations)

    # 定义 data loader

    mn_dataset_loader = torch.utils.data.DataLoader(dataset=custom_mnist_from_csv,

                                                    batch_size=10,

                                                    shuffle=False)

    

    for images, labels in mn_dataset_loader:

        # 将数据传给网络模型

需要注意的是使用多卡训练时,PyTorch dataloader 会将每个 batch 平均分配到各个 GPU。所以如果 batch size 过小,可能发挥不了多卡的效果。

### 回答1: 在 PyTorch 读取自定义数据集的一般步骤如下: 1. 定义数据集类:首先需要定义一个数据集类,继承自 `torch.utils.data.Dataset` 类,并实现 `__getitem__` 和 `__len__` 方法。在 `__getitem__` 方法,根据索引返回一个样本的数据和标签。 2. 加载数据集:使用 `torch.utils.data.DataLoader` 类加载数据集,可以设置批量大小、多线程读取数据等参数。 下面是一个简单的示例代码,演示如何使用 PyTorch 读取自定义数据集: ```python import torch from torch.utils.data import Dataset, DataLoader class CustomDataset(Dataset): def __init__(self, data, targets): self.data = data self.targets = targets def __getitem__(self, index): x = self.data[index] y = self.targets[index] return x, y def __len__(self): return len(self.data) # 加载训练集和测试集 train_data = ... train_targets = ... train_dataset = CustomDataset(train_data, train_targets) train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) test_data = ... test_targets = ... test_dataset = CustomDataset(test_data, test_targets) test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False) # 训练模型 for epoch in range(num_epochs): for batch_idx, (data, targets) in enumerate(train_loader): # 前向传播、反向传播,更新参数 ... ``` 在上面的示例代码,我们定义了一个 `CustomDataset` 类,加载了训练集和测试集,并使用 `DataLoader` 类分别对它们进行批量读取。在训练模型时,我们可以像使用 PyTorch 自带的数据集一样,循环遍历每个批次的数据和标签,进行前向传播、反向传播等操作。 ### 回答2: PyTorch是一个开源的深度学习框架,它提供了丰富的功能用于读取和处理自定义数据集。下面是一个简单的步骤来读取自定义数据集。 首先,我们需要定义一个自定义数据集类,该类应继承自`torch.utils.data.Dataset`类,并实现`__len__`和`__getitem__`方法。`__len__`方法应返回数据集的样本数量,`__getitem__`方法根据给定索引返回一个样本。 ```python import torch from torch.utils.data import Dataset class CustomDataset(Dataset): def __init__(self, data): self.data = data def __len__(self): return len(self.data) def __getitem__(self, idx): sample = self.data[idx] return torch.tensor(sample) ``` 接下来,我们可以创建一个数据集实例并传入自定义数据。假设我们有一个包含多个样本的列表 `data`。 ```python data = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] dataset = CustomDataset(data) ``` 然后,我们可以使用`torch.utils.data.DataLoader`类加载数据集,并指定批次大小、是否打乱数据等。 ```python batch_size = 2 dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True) ``` 现在,我们可以迭代数据加载器来获取批次的样本。 ```python for batch in dataloader: print(batch) ``` 上面的代码将打印出两个批次的样本。如果`shuffle`参数设置为`True`,则每个批次的样本将是随机的。 总而言之,PyTorch提供了简单而强大的工具来读取和处理自定义数据集,可以根据实际情况进行适当修改和扩展。 ### 回答3: PyTorch是一个流行的深度学习框架,可以用来训练神经网络模型。要使用PyTorch读取自定义数据集,可以按照以下几个步骤进行: 1. 准备数据集:将自定义数据集组织成合适的目录结构。通常情况下,可以将数据集分为训练集、验证集和测试集,每个集合分别放在不同的文件夹。确保每个文件夹的数据按照类别进行分类,以便后续的标签处理。 2. 创建数据加载器:在PyTorch,数据加载器是一个有助于有效读取和处理数据的类。可以使用`torchvision.datasets.ImageFolder`类创建一个数据加载器对象,通过传入数据集的目录路径来实现。 3. 数据预处理:在将数据传入模型之前,可能需要对数据进行一些预处理操作,例如图像变换、标准化或归一化等。可以使用`torchvision.transforms`的类来实现这些预处理操作,然后将它们传入数据加载器。 4. 创建数据迭代器:数据迭代器是连接数据集和模型的重要接口,它提供了一个逐批次加载数据的功能。可以使用`torch.utils.data.DataLoader`类创建数据迭代器对象,并设置一些参数,例如批量大小、是否打乱数据等。 5. 使用数据迭代器:在训练时,可以使用Python的迭代器来遍历数据集并加载数据。通常,它会在每个迭代步骤返回一个批次的数据和标签。可以通过`for`循环来遍历数据迭代器,并在每个步骤处理批次数据和标签。 这样,我们就可以在PyTorch成功读取并处理自定义数据集。通过这种方式,我们可以更好地利用PyTorch的功能来训练和评估自己的深度学习模型。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值