Environment
- OS: macOS Mojave
- Python version: 3.7
- PyTorch version: 1.4.0
- IDE: PyCharm
文章目录
0. 写在前面
本文记录一下使用 PyTorch 读取图像数据。数据按照特定的目录结构放好之后,需要构建 Dataset 和 DataLoader 对数据进行读取。
Dataset 定义了读取数据的位置(data_dir
)和方式(__getitem__
),而 DataLoader 给出 indices 决定让 Dataset 读取哪些数据。
1. 构造 Dataset
以划分好的 TinyMind人民币面值识别 任务的训练集为例,目录结构如下
├── rmbdata
│ ├── categorise.py
│ ├── split.py
│ ├── test
│ │ ├── images.jpg
│ ├── train
│ │ ├── images.jpg
│ ├── train
│ │ ├── images.jpg
│ ├── train_face_value_label.csv
│ ├── val
│ │ ├── images.jpg
1.1 法一:定义 Dataset 子类
在 data 模块中定义一个 torch.utils.data.Dataset
的子类,对 __getitem__
和 __len__
方法进行重写,在训练或评估中导入使用
├── data.py
└── rmbdata
import os
from PIL import Image
from torch.utils.data import Dataset
class RMBFaceValueDataset(Dataset):
"""Dataset for classifying RMB by the face value"""
def __init__(self, data_dir, transform=None, class_to_idx=None):
"""
Params:
data_dir: str
directory of data
transform: torchvision.transform
data transform approaches
idx_to_class: dict
map class name to a specified number
"""
self.data_dir = data_dir
self.transform = transform
self.class_to_idx = class_to_idx
self.data_info = self._get_img_info()
def __getitem__(self, index):
"""
Receive an index and return an example with its label
Returns:
image: torch.Tensor or PIL.Image
image data
label: int
label for this example
"""
image_path, label = self.data_info[index]
image = Image.open(image_path).convert('RGB')
if self.transform is not None:
image = self.transform(image)
return image, label
def __len__(self):
return len(self.data_info)
def _get_img_info(self):
data_info = []