PyTorch之数据加载和处理

学习如何构造和使用数据集类(datasets),转换(transforms)和数据加载器(dataloader)。

参考数据加载和处理 - PyTorch官方教程中文版

from __future__ import print_function, division  # 执行精准除法
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = 'TRUE'
import numpy as np
import pandas as pd  # 用于更容易地进行csv解析
from skimage import io, transform  # 用于图像的IO和变换
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils, datasets
import warnings
warnings.filterwarnings("ignore")


def show_landmarks(image, landmarks):
    """显示带有地标的图片"""
    plt.imshow(image)
    plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r')
    plt.pause(0.001)


def show_landmarks_batch(sample_batch):
    """Show image with landmarks for a batch of samples."""
    images_batch, landmarks_batch = sample_batch['image'], sample_batch['landmarks']
    batch_size = len(images_batch)
    im_size = images_batch.size(2)
    grid_border_size = 2
    grid = utils.make_grid(images_batch)
    plt.imshow(grid.numpy().transpose((1, 2, 0)))
    for idx in range(batch_size):
        plt.scatter(landmarks_batch[idx, :, 0].numpy() + idx * im_size + (idx + 1) * grid_border_size,
                    landmarks_batch[idx, :, 1].numpy() + grid_border_size, s=10, marker='.', c='r')
        plt.title('Batch from data_loader')


class FaceLandmarksDataset(Dataset):
    """面部标记数据集"""
    def __init__(self, csv_file, root_dir, transform=None):
        """
        csv_file(string):带注释的csv文件的路径。
        root_dir(string):包含所有图像的目录。
        transform(callable, optional):一个样本上的可用的可选变换
        """
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.landmarks_frame.iloc[idx, 0])
        image = io.imread(img_name)
        landmarks = self.landmarks_frame.iloc[idx, 1:]
        landmarks = np.array([landmarks])
        landmarks = landmarks.astype('float').reshape(-1, 2)
        _sample = {'image': image, 'landmarks': landmarks}
        if self.transform:
            _sample = self.transform(_sample)
        return _sample


class Rescale(object):
    """将样本中的图像重新缩放到给定大小.
    Args:
        output_size(tuple或int):所需的输出大小。如果是tuple,则输出为与output_size匹配。
            如果是int,则匹配较小的图像边缘到output_size保持纵横比相同。
    example:
        input=(640,600)  output_size=480  output=(640*480/600,480)=(512,480)
    """
    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, _sample):
        image, landmarks = _sample['image'], _sample['landmarks']  # image  landmarks
        h, w = image.shape[:2]  # height width
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size
        new_h, new_w = int(new_h), int(new_w)
        img = transform.resize(image, (new_h, new_w))
        # x and y axes are axis 1 and 0 respectively, landmark=[x,y]=[col,row]=[axis1,axis0]
        landmarks = landmarks * [new_w / w, new_h / h]
        return {'image': img, 'landmarks': landmarks}


class RandomCrop(object):
    """随机裁剪样本中的图像.
    Args:
       output_size(tuple或int):所需的输出大小。 如果是int,方形裁剪是。
    """
    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, _sample):
        image, landmarks = _sample['image'], _sample['landmarks']
        h, w = image.shape[:2]
        new_h, new_w = self.output_size
        top = np.random.randint(0, h - new_h)
        left = np.random.randint(0, w - new_w)
        image = image[top: top + new_h, left: left + new_w]
        landmarks = landmarks - [left, top]
        return {'image': image, 'landmarks': landmarks}


class ToTensor(object):
    """将样本中的ndarrays(多维数组)转换为Tensors"""
    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']
        # 交换颜色轴因为
        # numpy包的图片是: H * W * C
        # torch包的图片是: C * H * W
        image = image.transpose((2, 0, 1))  # 通道转换
        return {'image': torch.from_numpy(image), 'landmarks': torch.from_numpy(landmarks)}  # 转为torch的tensor格式


if __name__ == "__main__":

    print('########### 展示图片 ##########')
    landmarks_frames = pd.read_csv('data/faces/face_landmarks.csv')
    index = 65
    src_name = landmarks_frames.iloc[index, 0]
    src_landmarks = landmarks_frames.iloc[index, 1:].values
    src_landmarks = src_landmarks.astype('float').reshape(-1, 2)
    print('Image name: {}'.format(src_name))
    print('Landmarks shape: {}'.format(src_landmarks.shape))
    print('First 4 Landmarks:\n{}'.format(src_landmarks[:4]))
    plt.figure()
    show_landmarks(io.imread(os.path.join('data/faces/', src_name)), src_landmarks)  # 调用显示函数
    plt.show()

    print('########### 数据展示 ##########')
    face_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv', root_dir='data/faces/')
    plt.figure()
    for i in range(len(face_dataset)):
        data_sample = face_dataset[i]
        print(i, data_sample['image'].shape, data_sample['landmarks'].shape)
        ax = plt.subplot(1, 4, i + 1)
        plt.tight_layout()
        ax.set_title('Sample #{}'.format(i))  # set title
        ax.axis('off')
        show_landmarks(**data_sample)
        # show_landmarks(data_sample['image'], data_sample['landmarks'])
        if i == 3:
            plt.show()
            break

    print('########### 数据转换 ##########')
    scale = Rescale(256)
    crop = RandomCrop(128)
    composed = transforms.Compose([Rescale(256), RandomCrop(224)])  # 组合转换
    fig = plt.figure()
    sample = face_dataset[65]
    for i, tsf in enumerate([scale, crop, composed]):
        transformed_sample = tsf(sample)
        ax = plt.subplot(1, 3, i + 1)
        plt.tight_layout()
        ax.set_title(type(tsf).__name__)  # set title
        show_landmarks(**transformed_sample)
    plt.show()

    print('########### 迭代数据集 ##########')
    transformed_dataset = FaceLandmarksDataset(
        csv_file='data/faces/face_landmarks.csv', root_dir='data/faces/',
        transform=transforms.Compose([Rescale(256), RandomCrop(224), ToTensor()]))
    for i in range(len(transformed_dataset)):
        sample = transformed_dataset[i]
        print(i, sample['image'].size(), sample['landmarks'].size())
        if i == 3:
            break

    print('########### 批次迭代数据集 ##########')
    data_loader = DataLoader(transformed_dataset, batch_size=4, shuffle=True, num_workers=4)
    for i_batch, sample_batched in enumerate(data_loader):
        print(i_batch, sample_batched['image'].size(), sample_batched['landmarks'].size())
        # 观察第4批次并停止
        if i_batch == 3:
            plt.figure()
            show_landmarks_batch(sample_batched)
            plt.axis('off')
            plt.ioff()
            plt.show()
            break

    print('########### 利用torchvision创建数据加载器 ##########')
    data_transform = transforms.Compose([
        transforms.RandomSizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    # ants_bees_dataset = datasets.ImageFolder(root='ants_bees_data/train', transform=data_transform)
    # dataset_loader = DataLoader(ants_bees_dataset, batch_size=4, shuffle=True, num_workers=4)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值