PyTorch 编写自定义数据集,数据加载器和转换

本文为 pytorch 官方教程https://pytorch.org/tutorials/beginner/data_loading_tutorial.html代码的注释

w3cschool 的翻译版本:https://www.w3cschool.cn/pytorch/pytorch-typm3be3.html

from __future__ import print_function, division
import os
import torch
import pandas as pd  # 用于解析 csv 文件
from skimage import io, transform  # 用于图像io和变换
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

# matplotlib 有两种工作模式:交互模式和阻塞模式;
# 使用 plt.ion() 打开交互模式 plt.plot(x) 或 plt.imshow(x) 直接打印图像,不需要 plt.show()
# 阻塞模式下 plt.plot(x) 或 plt.imshow(x) 需要 plt.show() 后才能显示图像
plt.ion()   
landmarks_frame = pd.read_csv('data/faces/face_landmarks.csv')  # 使用 pandas 读取 .csv文件

# 将 landmarks_frame 的第 65 行的标签重组成点的集合(方便打印为散点图)
n = 65
img_name = landmarks_frame.iloc[n, 0]
landmarks = landmarks_frame.iloc[n, 1:]
landmarks = np.asarray(landmarks)
landmarks = landmarks.astype('float').reshape(-1, 2)

# 打印第 65 张图片的名称
print('Image name: {}'.format(img_name))
# 打印第 65 张 图片的标签维度
print('Landmarks shape: {}'.format(landmarks.shape))
# 打印预览 4 个点标签
print('First 4 Landmarks: {}'.format(landmarks[:4]))
landmarks[:, 0].shape
# 编写一个简单的辅助函数来显示图像及其地标
def show_landmarks(image, landmarks):
    """Show image with landmarks"""
    plt.imshow(image)  # 打印图片
    plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r') # 打印散点图,横坐标和纵坐标都需要输入一维数据
    plt.pause(0.001)  # pause a bit so that plots are updated

plt.figure()  # 创建一个画板
show_landmarks(io.imread(os.path.join('data/faces/', img_name)), landmarks)  # 调用 show_landmarks 函数在画板上作画
plt.show()  # 显示画板上的图画
# 编写自定义数据集,使其满足 pytorch 数据集的标准
# __getitem__ 方法和 __len__ 方法是必须的,因为数据加载器 DataLoader 需要使用这两个方法实现数据规划功能
# __getitem__ 方法使得自定义数据集对象可以使用下标返回一个包含图片和标签的字典
# 数据集样本的格式:dict {'image': image, 'landmarks': landmarks}
class FaceLandmarksDataset(Dataset):
    """Face Landmarks dataset."""

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.root_dir,
                                self.landmarks_frame.iloc[idx, 0])
        image = io.imread(img_name)
        # 将标签重组为坐标点的集合
        landmarks = self.landmarks_frame.iloc[idx, 1:]
        landmarks = np.array([landmarks])
        landmarks = landmarks.astype('float').reshape(-1, 2)
        sample = {'image': image, 'landmarks': landmarks}

        if self.transform:
            sample = self.transform(sample)

        return sample
# 打印前 4 个样本的大小并显示其地标
# 创建一个自定义数据集对象
face_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',
                                    root_dir='data/faces/')
# 创建一个画板对象
fig = plt.figure()

for i in range(len(face_dataset)):
    sample = face_dataset[i]
    # 打印图片的索引,图片的形状,标签的形状
    print(i, sample['image'].shape, sample['landmarks'].shape)
    # 在画板上创建一个子画布对象,画板被划分为 1 x 4 个画布
    ax = plt.subplot(1, 4, i + 1)
    plt.tight_layout()  # 自动调整子图参数,使之填充整个图像区域
    ax.set_title('Sample #{}'.format(i))  # 设置画布的 title
    ax.axis('off')  # 不要打印画布的坐标轴
    show_landmarks(**sample)  # 调用函数打印图形及地标,双星号可以用来获得字典的值(字典的索引名要符合函数的参数要求)
    # 相当于 show_landmarks(image=sample['image'],landmarks=sample['landmarks'])
    if i == 3:
        plt.show()
        break
## 自定义数据预处理工具,使其满足 pytorch 标准,__call__ 函数是必须的 (transforms.Compose 的要求)
# 缩放或拉伸图像以标准化数据集
# 使用 transform 的 resize 方法改变图像的大小
class Rescale(object):
    """Rescale the image in a sample to a given size.

    Args:
        output_size (tuple or int): Desired output size. If tuple, output is
            matched to output_size. If int, smaller of image edges is matched
            to output_size keeping aspect ratio the same.
    """

    def __init__(self, output_size):
        # assert 宏的原型定义在 assert.h 中,其作用是如果它的条件返回错误,则终止程序执行。
        # isinstance 判断 output_size 是否是 (int, tuple) 中的一个 
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    # __call__ 方法使得对象可以像函数一样被调用
    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']
        
        # h 为图像的高(第一维),w 为宽(第二维)
        h, w = image.shape[:2]
        # 如果 output_size 是 int 数据类型,就将将图像保持宽高比放大
        # 如果 output_size 是 tuple 数据类型,图像宽高由 output_size 确定
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)
        
        # 使用 transform 的 resize 方法改变图像的大小
        img = transform.resize(image, (new_h, new_w))

        # 标记点也要相应的等比例调整
        landmarks = landmarks * [new_w / w, new_h / h]

        return {'image': img, 'landmarks': landmarks}

# 随机裁剪图样以标准化数据集
# 使用切片方式裁剪图片
class RandomCrop(object):
    """Crop randomly the image in a sample.

    Args:
        output_size (tuple or int): Desired output size. If int, square crop
            is made.
    """

    def __init__(self, output_size):
        # 检查初始化参数的数据类型
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            # 检查字典的维度
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        new_h, new_w = self.output_size

        # randint 的取值区间是左闭右开的
        top = np.random.randint(0, h - new_h)
        left = np.random.randint(0, w - new_w)

        # 使用切片方式裁剪图片
        image = image[top: top + new_h,
                      left: left + new_w]

        landmarks = landmarks - [left, top]

        return {'image': image, 'landmarks': landmarks}
    
# 将 numpy 型数据转换为 tensor
# 使用 torch.from_numpy 方法实现
class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        # swap color axis because
        # numpy image: H x W x C
        # torch image: C X H X W
        image = image.transpose((2, 0, 1))
        return {'image': torch.from_numpy(image),
                'landmarks': torch.from_numpy(landmarks)}
# 创建数据预处理对象
# 等宽高比将图片最低维放缩到 256
scale = Rescale(256)
# 随机裁剪图片为 128 * 128 大小
crop = RandomCrop(128)
# 将两个图像预处理步骤组合
composed = transforms.Compose([Rescale(256),
                               RandomCrop(224)])

# 应用图像预处理到 65 号索引的图像
# 创建画板
fig = plt.figure()
sample = face_dataset[65]
for i, tsfrm in enumerate([scale, crop, composed]):
    transformed_sample = tsfrm(sample)
    # 使用画布分割画板
    ax = plt.subplot(1, 3, i + 1)
    plt.tight_layout()
    ax.set_title(type(tsfrm).__name__)
    # 在画布上绘图
    show_landmarks(**transformed_sample)
# 显示绘制完成的画板
plt.show()
# 生成经过预处理的 dataset 数据集
transformed_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',
                                           root_dir='data/faces/',
                                           transform=transforms.Compose([
                                               Rescale(256),
                                               RandomCrop(224),
                                               ToTensor()
                                           ]))
# 打印 4 组数据的维度,检查预处理效果
for i in range(len(transformed_dataset)):
    sample = transformed_dataset[i]

    print(i, sample['image'].size(), sample['landmarks'].size())

    if i == 3:
        break
dataloader = DataLoader(transformed_dataset, batch_size=4,
                        shuffle=True, num_workers=4)

# 编写函数打印一个 batch 的图像
def show_landmarks_batch(sample_batched):
    """Show image with landmarks for a batch of samples."""
    images_batch, landmarks_batch = \
            sample_batched['image'], sample_batched['landmarks']
    batch_size = len(images_batch)
    im_size = images_batch.size(2)
    grid_border_size = 2

    grid = utils.make_grid(images_batch)
    plt.imshow(grid.numpy().transpose((1, 2, 0)))

    for i in range(batch_size):
        plt.scatter(landmarks_batch[i, :, 0].numpy() + i * im_size + (i + 1) * grid_border_size,
                    landmarks_batch[i, :, 1].numpy() + grid_border_size,
                    s=10, marker='.', c='r')

        plt.title('Batch from dataloader')

for i_batch, sample_batched in enumerate(dataloader):
    # 查看分 batch 后的数据维度
    print(i_batch, sample_batched['image'].size(),
          sample_batched['landmarks'].size())
    # observe 4th batch and stop.
    if i_batch == 3:
        plt.figure()
        show_landmarks_batch(sample_batched)
        plt.axis('off')
        plt.ioff()
        plt.show()
        break
        break

 

  • 1
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值