Notes on the official PyTorch tutorial "Writing Custom Datasets, DataLoaders and Transforms"

Source code:

from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils


# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

plt.ion()   # interactive mode


def show_landmarks(image, landmarks):
    """Show image with landmarks"""
    plt.imshow(image)
    plt.scatter(landmarks[:, 0], landmarks[:, 1], s=100, marker='.', c='r')


class FaceLandmarksDataset(Dataset):
    """Face Landmarks dataset."""

    # Initializer.
    # Inputs: path to the csv annotation file, the root directory of the images, and an optional transform object.
    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.landmarks_frame = pd.read_csv(csv_file)   # read the csv annotations into the Dataset
        self.root_dir = root_dir   # keep the root directory path on the Dataset; it is needed later in __getitem__()
        self.transform = transform

    def __len__(self):
        return len(self.landmarks_frame)

    def __getitem__(self, idx):   # the DataLoader passes an int index when it fetches a sample
        if torch.is_tensor(idx):   # if idx arrives as a tensor, convert it to a plain Python list so pandas .iloc can index with it
            idx = idx.tolist()

        img_name = os.path.join(self.root_dir, self.landmarks_frame.iloc[idx, 0])   # build the full image path; note: running this line in isolation raises an error, see draft

        image = io.imread(img_name)   # read the image from disk
        landmarks = self.landmarks_frame.iloc[idx, 1:]
        landmarks = np.asarray(landmarks)
        landmarks = landmarks.astype('float').reshape(-1, 2)
        sample = {'image': image, 'landmarks': landmarks}   # pack the image and its landmarks into a dict sample

        if self.transform:
            sample = self.transform(sample)   # the transform object implements __call__(), so it can be called like a function

        return sample


# Create an instance of FaceLandmarksDataset and iterate over its samples: print the size of the first 4 samples and show their landmarks (see the sketch below).
face_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv', root_dir='data/faces/')
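
# A minimal sketch of that iteration, assuming the faces dataset has been extracted to data/faces/
# and reusing the show_landmarks() helper defined above:
fig = plt.figure()
for i in range(4):
    sample = face_dataset[i]   # calls FaceLandmarksDataset.__getitem__(i)
    print(i, sample['image'].shape, sample['landmarks'].shape)
    ax = plt.subplot(1, 4, i + 1)
    ax.set_title('Sample #{}'.format(i))
    ax.axis('off')
    show_landmarks(**sample)   # unpack the dict into image= and landmarks=
plt.show()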


class Rescale(object):
    """Rescale the image in a sample to a given size.

    Args:
        output_size (tuple or int): Desired output size. If tuple, output is
            matched to output_size. If int, smaller of image edges is matched
            to output_size keeping aspect ratio the same.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)

        img = transform.resize(image, (new_h, new_w))

        # h and w are swapped for landmarks because for images,
        # x and y axes are axis 1 and 0 respectively
        landmarks = landmarks * [new_w / w, new_h / h]

        return {'image': img, 'landmarks': landmarks}


class RandomCrop(object):
    """Crop randomly the image in a sample.

    Args:
        output_size (tuple or int): Desired output size. If int, square crop
            is made.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        new_h, new_w = self.output_size

        top = np.random.randint(0, h - new_h)
        left = np.random.randint(0, w - new_w)

        image = image[top: top + new_h, left: left + new_w]

        landmarks = landmarks - [left, top]   # important: once image is replaced by the crop, the landmarks must be shifted by the same offset

        return {'image': image, 'landmarks': landmarks}


class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        # swap color axis because
        # numpy image: H x W x C
        # torch image: C X H X W
        image = image.transpose((2, 0, 1))
        return {'image': torch.from_numpy(image),
                'landmarks': torch.from_numpy(landmarks)}
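
# Before wiring everything into the dataset, each transform can be tried on a single sample.
# A short sketch, assuming face_dataset from above (the sample index 65 is arbitrary):
scale = Rescale(256)
crop = RandomCrop(128)
composed = transforms.Compose([Rescale(256), RandomCrop(224)])

sample = face_dataset[65]
fig = plt.figure()
for i, tsfrm in enumerate([scale, crop, composed]):
    transformed_sample = tsfrm(sample)   # each transform object implements __call__
    ax = plt.subplot(1, 3, i + 1)
    ax.set_title(type(tsfrm).__name__)
    show_landmarks(**transformed_sample)
plt.show()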


transformed_dataset = FaceLandmarksDataset(
    csv_file='data/faces/face_landmarks.csv',
    root_dir='data/faces/',
    transform=transforms.Compose([Rescale(256), RandomCrop(224), ToTensor()]))

dataloader = DataLoader(transformed_dataset, batch_size=4, shuffle=True, num_workers=0)


# Helper function to show a batch
def show_landmarks_batch(sample_batched):
    """Show image with landmarks for a batch of samples."""
    images_batch, landmarks_batch = sample_batched['image'], sample_batched['landmarks']
    batch_size = len(images_batch)
    im_size = images_batch.size(2)
    grid_border_size = 2

    grid = utils.make_grid(images_batch)
    plt.imshow(grid.numpy().transpose((1, 2, 0)))

    for i in range(batch_size):
        plt.scatter(landmarks_batch[i, :, 0].numpy() + i * im_size + (i + 1) * grid_border_size,
                    landmarks_batch[i, :, 1].numpy() + grid_border_size, s=10, marker='.', c='r')
        plt.title('Batch from dataloader')


for i_batch, sample_batched in enumerate(dataloader):
    # In our custom dataset, __getitem__() returns a dict.
    # Here the 4 samples are collated together, and what sample_batched receives is still a dict:
    # its 'image' key is one big 4x3x224x224 tensor with the 4 image arrays stacked together,
    # and its 'landmarks' key is likewise one big tensor with the 4 sets of landmarks stacked together,
    # so sample_batched mirrors the structure our custom dataset returns.
    print(i_batch, sample_batched['image'].size(), sample_batched['landmarks'].size())

    # observe 4th batch and stop.
    if i_batch == 3:
        plt.figure()
        show_landmarks_batch(sample_batched)
        plt.axis('off')
        plt.ioff()
        plt.show()
        break

APIs and functions used:

Joining path components:

os.path.join()
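
A quick illustration of how it is used in __getitem__ above (the filename is hypothetical):

os.path.join('data/faces/', 'some_face.jpg')   # -> 'data/faces/some_face.jpg'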

Read an image from a path into a 3-D array:

skimage.io.imread("path/to/image")
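
The result is a numpy ndarray; a quick check (filename hypothetical, shape illustrative):

image = io.imread('data/faces/some_face.jpg')
print(image.shape)   # e.g. (H, W, 3) for an RGB image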

Continue execution only if output_size is an int or a tuple, otherwise raise an AssertionError (the assert statement):

assert isinstance(output_size, (int, tuple))
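
If the check fails, an AssertionError is raised on the spot; for instance, with the Rescale class above:

Rescale(256)         # ok: int
Rescale((256, 256))  # ok: tuple
Rescale('256')       # raises AssertionError: a str is neither int nor tuple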

Resize an image to a given size:

skimage.transform.resize(image, (new_h, new_w))
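
One detail worth keeping in mind: by default skimage.transform.resize returns a floating-point image with values rescaled to [0, 1]:

img = transform.resize(image, (new_h, new_w))   # float image in [0, 1], shape (new_h, new_w, 3)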

Compose several transform objects (instances of our custom transform classes) into one combined transform, which is then passed to the custom dataset as its transform member:

torchvision.transforms.Compose([transform1, transform2, ...])

Make a grid of images:

torchvision.utils.make_grid(images_batch)
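
make_grid tiles the batch into a single (3, H', W') image tensor; its default padding between images is 2 pixels, which is why show_landmarks_batch offsets the landmarks by grid_border_size = 2:

grid = utils.make_grid(images_batch)            # default padding=2
plt.imshow(grid.numpy().transpose((1, 2, 0)))   # back to H x W x C for matplotlib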