Pytorch自定义数据集(Custom Dataset)的读取方式

82 篇文章 0 订阅

相关模块:torchvision

torchvision 是独立于pytorch 之外的图像操作库
具体介绍详见:DrHW的文章

torchvision主要包括一下几个包:1

  • torchvision.datasets : 几个常用视觉数据集,可以下载和加载这里主要的高级用法就是可以看源码如何自己写自己的Dataset的子类
    这部分就是本文要介绍的重点
  • torchvision.models: 流行的模型,例如 AlexNet, VGG, ResNet 和 Densenet 以及 与训练好的参数。
  • torchvision.transforms : 常用的图像操作,例如:随机切割,旋转,数据类型转换,图像到tensor ,numpy 数组到tensor , tensor 到 图像等。
  • torchvision.utils : 用于把形似 (3 x H x W) 的张量保存到硬盘中,给一个mini-batch的图像可以产生一个图像格网。
    shape = (channel, height, weight)

具体操作

  • 自定义数据集的基础方法:

引文2

"""
inout pipline for custom dataset
"""
from torch.utils.data.dataset import Dataset
class CustomDataset(Dataset):
    def __init__(self):
    	"""
    	一些初始化过程写在这里
    	"""
        # TODO
        # 1. Initialize file paths or a list of file names. 
        pass
    def __getitem__(self, index):
    	"""
    	返回数据和标签,可以这样显示调用:
    	img, label = MyCustomDataset.__getitem__(99)
    	"""
        # TODO
        # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
        # 2. Preprocess the data (e.g. torchvision.Transform).
        # 3. Return a data pair (e.g. image and label).
        pass
    def __len__(self):
    	"""
    	返回所有数据的数量
    	"""
        # You should change 9 to the total size of your dataset.
        return 9 # e.g. 9 is size of dataset
使用 Torchvision Transforms
  • 方法一:
from torch.utils.data.dataset import Dataset
from torchvision import transforms

class MyCustomDataset(Dataset):
    def __init__(self, ..., transforms=None):
        # stuff
        ...
        self.transforms = transforms
        
    def __getitem__(self, index):
        # stuff
        ...
        data = # 一些读取的数据
        if self.transforms is not None:
            data = self.transforms(data)
        # 如果 transform 不为 None,则进行 transform 操作
        return (img, label)
 
    def __len__(self):
        return count 
        
if __name__ == \'__main__\':
    # 定义我们的 transforms (1)
    transformations = transforms.Compose([transforms.CenterCrop(100), transforms.ToTensor()])
    # 创建 dataset
    custom_dataset = MyCustomDataset(..., transformations)
  • 方法二:
    有些人不喜欢将transform写在Dataset外, 即在Dataset内定义transform

from torch.utils.data.dataset import Dataset
from torchvision import transforms
 
class MyCustomDataset(Dataset):
    def __init__(self, ...):
        # stuff
        ...
        # (2) 一种方法是单独定义 transform
        self.center_crop = transforms.CenterCrop(100)
        self.to_tensor = transforms.ToTensor()
        
        # (3) 或者写成下面这样 
        self.transformations = \
            transforms.Compose([transforms.CenterCrop(100),
                                transforms.ToTensor()])
        
    def __getitem__(self, index):
        # stuff
        ...
        data = #一些读取的数据
        
        # 当第二次调用 transform 时,调用的是 __call__()
        data = self.center_crop(data)  # (2)
        data = self.to_tensor(data)  # (2)
        
        # 或者写成下面这样
        data = self.trasnformations(data)  # (3)
        
        # 注意 (2) 和 (3) 中只需要实现一种
        return (img, label)
 
    def __len__(self):
        return count
        
if __name__ == \'__main__\':
    custom_dataset = MyCustomDataset(...)

结合 Pandas 使用 getitem()

另一种情况是 csv 文件中保存了我们需要的图像文件的像素值(比如有些 MNIST 教程就是这样的)。我们需要改动一下 getitem() 函数。

Labelpixel_1pixel_2
15099
021223
944112
class CustomDatasetFromCSV(Dataset):
    def __init__(self, csv_path, height, width, transforms=None):
        """
        Args:
            csv_path (string): csv 文件路径
            height (int): 图像高度
            width (int): 图像宽度
            transform: transform 操作
        """
        self.data = pd.read_csv(csv_path)
        self.labels = np.asarray(self.data.iloc[:, 0])
        self.height = height
        self.width = width
        self.transforms = transform
 
    def __getitem__(self, index):
        single_image_label = self.labels[index]
        # 读取所有像素值,并将 1D array ([784]) reshape 成为 2D array ([28,28]) 
        img_as_np = np.asarray(self.data.iloc[index][1:]).reshape(28,28).astype(\'uint8\')
	# 把 numpy array 格式的图像转换成灰度 PIL image
        img_as_img = Image.fromarray(img_as_np)
        img_as_img = img_as_img.convert(\'L\')
        # 将图像转换成 tensor
        if self.transforms is not None:
            img_as_tensor = self.transforms(img_as_img)
        # 返回图像及其 label
        return (img_as_tensor, single_image_label)
 
    def __len__(self):
        return len(self.data.index)
        
 
if __name__ == "__main__":
    transformations = transforms.Compose([transforms.ToTensor()])
    custom_mnist_from_csv = \
        CustomDatasetFromCSV(\'../data/mnist_in_csv.csv\', 28, 28, transformations)

使用 Dataloader 读取自定义数据集

PyTorch 中的 Dataloader 只是调用 getitem() 方法并组合成 batch,我们可以这样调用:


...
if __name__ == "__main__":
    # 定义 transforms
    transformations = transforms.Compose([transforms.ToTensor()])
    # 自定义数据集
    custom_mnist_from_csv = \
        CustomDatasetFromCSV(\'../data/mnist_in_csv.csv\',
                             28, 28,
                             transformations)
    # 定义 data loader
    mn_dataset_loader = torch.utils.data.DataLoader(dataset=custom_mnist_from_csv,
                                                    batch_size=10,
                                                    shuffle=False)
    
    for images, labels in mn_dataset_loader:
        # 将数据传给网络模型 

需要注意的是使用多卡训练时,PyTorch dataloader 会将每个 batch 平均分配到各个 GPU。所以如果 batch size 过小,可能发挥不了多卡的效果。

Stanford Dogs 数据集自定义实例

from torch.utils.data.dataset import Dataset
from torchvision import transforms

class MyDateset(Dataset):
    def __init__(self, file_folder, is_test=False, transform=None):
        self.img_folder_path = '../input/images/Images/'
        self.annotation_folder_path = '../input/annotations/Annotation/'
        self.file_folder = file_folder
        self.transform = transform
        #self.transform = transforms.Compose
        self.is_test = is_test
        
    def __getitem__(self, idx):
        file = self.file_folder[idx]
        img_path = self.img_folder_path + file
        img = Image.open(img_path).convert('RGB')
        
        if not self.is_test:
            annotation_path = self.annotation_folder_path + file.split('.')[0]
            with open(annotation_path) as f:
                annotation = f.read()

            xy = self.get_xy(annotation)
            box = torch.FloatTensor(list(xy))

            new_box = self.box_resize(box, img)
            if self.transform is not None:
                img = self.transform(img)

            return img, new_box
        else:
            if self.transform is not None:
                img = self.transform(img)
            return img
    
    def __len__(self):
        return len(self.file_folder)
        
    def get_xy(self, annotation):
        xmin = int(re.findall('(?<=<xmin>)[0-9]+?(?=</xmin>)', annotation)[0])
        xmax = int(re.findall('(?<=<xmax>)[0-9]+?(?=</xmax>)', annotation)[0])
        ymin = int(re.findall('(?<=<ymin>)[0-9]+?(?=</ymin>)', annotation)[0])
        ymax = int(re.findall('(?<=<ymax>)[0-9]+?(?=</ymax>)', annotation)[0])
        
        return xmin, ymin, xmax, ymax
    
    def show_box(self):
        file = random.choice(self.file_folder)
        annotation_path = self.annotation_folder_path + file.split('.')[0]
        
        img_box = Image.open(self.img_folder_path + file)
        with open(annotation_path) as f:
            annotation = f.read()
            
        draw = ImageDraw.Draw(img_box)
        xy = self.get_xy(annotation)
        print('bbox:', xy)
        draw.rectangle(xy=[xy[:2], xy[2:]])
        
        return img_box
        
    def box_resize(self, box, img, dims=(332, 332)):
        old_dims = torch.FloatTensor([img.width, img.height, img.width, img.height]).unsqueeze(0)
        new_box = box / old_dims
        new_dims = torch.FloatTensor([dims[1], dims[0], dims[1], dims[0]]).unsqueeze(0)
        new_box = new_box * new_dims
        
        return new_box

FaceLandmarks实例

class FaceLandmarksDataset(Dataset):
    """Face Landmarks dataset."""

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir,
                                self.landmarks_frame.iloc[idx, 0])
        image = io.imread(img_name)
        landmarks = self.landmarks_frame.iloc[idx, 1:].as_matrix()
        landmarks = landmarks.astype('float').reshape(-1, 2)
        sample = {'image': image, 'landmarks': landmarks}

        if self.transform:
            sample = self.transform(sample)

        return sample

参考文献:


  1. https://www.cnblogs.com/yjphhw/p/9773333.html ↩︎

  2. https://github.com/yunjey/pytorch-tutorial/ ↩︎

  • 9
    点赞
  • 45
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
PyTorch 中,自定义数据集可以通过继承 `torch.utils.data.Dataset` 类来实现。这个类需要实现两个方法:`__len__` 和 `__getitem__`。 `__len__` 方法返回数据集的长度,即样本数量。`__getitem__` 方法返回数据集中一个索引对应的样本。 下面是一个简单的例子,假设我们有一个文件夹 `data`,里面包含若干张图片和对应的标签,我们要把这个数据集PyTorch 加载起来: ```python import os from PIL import Image import torch.utils.data as data class CustomDataset(data.Dataset): def __init__(self, root_dir): self.root_dir = root_dir self.img_list = os.listdir(root_dir) def __len__(self): return len(self.img_list) def __getitem__(self, index): img_path = os.path.join(self.root_dir, self.img_list[index]) img = Image.open(img_path).convert('RGB') label = int(self.img_list[index].split('_')[0]) return img, label ``` 在上面的例子中,我们定义了一个 `CustomDataset` 类,它有一个构造函数 `__init__`,接收一个参数 `root_dir` 表示数据集所在的文件夹路径。`__init__` 方法初始化了 `img_list` 属性,里面保存了所有图片文件名。 `__len__` 方法返回了 `img_list` 的长度,即数据集中样本的数量。 `__getitem__` 方法接收一个索引 `index`,返回了数据集中第 `index` 个样本的图片和标签。具体地,它首先获取了图片文件的路径,然后用 `PIL` 库打开图片并转换成 RGB 模式。最后,它从文件名中解析出标签信息,并把图片和标签一起返回。 有了这个自定义数据集类,我们就可以用 PyTorch 的 `DataLoader` 类来加载数据集了。例如: ```python import torch.utils.data as data dataset = CustomDataset('data') dataloader = data.DataLoader(dataset, batch_size=32, shuffle=True) ``` 在上面的例子中,我们创建了一个 `CustomDataset` 对象 `dataset`,然后用 `DataLoader` 类来初始化 `dataloader` 对象。`DataLoader` 的第一个参数是数据集对象,第二个参数是批量大小,第三个参数是是否打乱数据集顺序。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值