【用法总结】利用Pytorch进行数据加载和预处理的实现思路

整理一下利用Pytorch进行数据加载和预处理的实现思路:
主要分以下三种情况:

1 对于torchvision提供的数据集

  • 这是最简单的一种情况。
  • 对于这一类数据集,就是PyTorch已经帮我们做好了所有的事情,连数据源都不需要自己下载。
  • Imagenet,CIFAR10,MNIST等等PyTorch都提供了数据加载的功能,所以可以先看看你要用的数据集是不是这种情况。
    import torch
    import torchvision
    import torchvision.transforms as transforms
    
    transform = transforms.Compose(
        [transforms.ToTensor(), # 归一化到(0,1),直接除以255 
         transforms.Normalize(std=(0.5, 0.5, 0.5), mean=(0.5, 0.5, 0.5))# 归一化到(-1,1),channel=(channel-mean)/std
        ]
    )
    train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=transform)
    train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=4, shuffle=True, num_workers=2)
    
    test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=False, transform=transform)
    test_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=4, shuffle=False, num_workers=2)

     

2 对于特定结构的数据集

  • 这种情况就是不在上述PyTorch提供数据库之列,但是满足下面的形式:
     root/ants/xxx.png
     root/ants/xxy.jpeg
     root/ants/xxz.png
    .
    .
    .
    root/bees/123.jpg
    root/bees/nsdf3.png
    root/bees/asd932_.png
    
  • 那么就可以通过torchvision中的通用数据集ImageFolder来完成加载。
  • 具体使用方法:
    import torch
    from torchvision import transforms, datasets
    
    data_transform = transforms.Compose([
            transforms.RandomSizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
    hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',
                                               transform=data_transform)
    dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,
                                                 batch_size=4, shuffle=True,
                                                 num_workers=4)

     

3 对于最普通的数据集

  • 最后一种情况是既不是自带数据集,又不满足ImageFolder,这种时候就自己进行处理。
  • 首先,定义数据集的类(myDataset),这个类要继承dataset这个抽象类,并实现__len__以及__getitem__这两个函数,通常情况还包括初始函数__init__.
  • 然后,实现用于特定图像预处理的功能,并封装成类。当然常用的一些变换可以在torchvision中找到。用torchvision.transforms.Compose将它们进行组合成(transform)
  • transform作为上面myDataset类的参数传入,并得到实例化myDataset得到(transformed_dataset)对象。
  • 最后,将transformed_dataset作为torch.utils.data.DataLoader类的形参,并根据需求设置自己是否需要打乱顺序,批大小...
  • 具体见:https://pytorch.org/tutorials/beginner/data_loading_tutorial.html
from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

# Ignore warning
import warnings
warnings.filterwarnings("ignore")

plt.ion()  # interactive mode

######################################################################
# 数据读取
landmarks_frame = pd.read_csv('./data/faces/face_landmarks.csv')
# landmarks_frame.info()

n = 65
img_name = landmarks_frame.iloc[n, 0]    # 索引得到第n行、第0列(照片名)
landmarks = landmarks_frame.iloc[n, 1:].as_matrix()   # 索引得到第n行、第1~137列(annotation的横纵坐标)
landmarks = landmarks.astype('float').reshape(-1, 2)  # reshape为(68,2)的形状,即第一列为散点横坐标,第二列为纵坐标

print('Image name: {}'.format(img_name))              # 查看第n张的照片名
print('Landmarks shape: {}'.format(landmarks.shape))
print('First 4 Landmarks: {}'.format(landmarks[:4]))  # 查看前四个点的(x,y)


def show_landmarks(image, landmarks):
    # show image with landmarks
    plt.imshow(image)
    plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r')
    # plt.pause(0.001)  # pause a bit so that plots are updated
    plt.pause(3)


plt.figure()
img = io.imread(os.path.join('./data/faces/', img_name))
show_landmarks(img, landmarks)
plt.show()

######################################################################
# 定义人脸标记数据集
#
class FaceLandmarksDataset(Dataset): # 继承Dataset

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self): # __len__返回数据集的大小,用法:len(dataset)
        return len(self.landmarks_frame)

    def __getitem__(self, idx):# 支持整数idx索引,范围从0到len(self),用法:dataset[i]得到索引为i的样本及标签
        img_name = os.path.join(self.root_dir, self.landmarks_frame.iloc[idx, 0])
        image = io.imread(img_name)
        landmarks = self.landmarks_frame.iloc[idx, 1:].as_matrix()
        landmarks = landmarks.astype('float').reshape(-1, 2)
        sample = {'image': image, 'landmarks': landmarks} # 返回字典形式dict

        if self.transform:
            sample = self.transform(sample)  # 可以实现裁剪缩放等数据转换(transform类是有__call__方法的)
                                             # 所以就可以利用函数形式transform(sample)来进行变换
        return sample

######################################################################
# 实例化人脸数据集类、并show出来
face_dataset = FaceLandmarksDataset(csv_file='./data/faces/face_landmarks.csv', root_dir='./data/faces/')

fig = plt.figure()

for i in range(len(face_dataset)):
    sample = face_dataset[i] # 实例化第i个样本的image和landmark: 因为有__getitem__ 方法,所以可以根据索引得到样本的字典形式
    print(i, sample['image'].shape, sample['landmarks'].shape)
    ax = plt.subplot(1, 4, i+1)
    plt.tight_layout()
    ax.set_title('Sample #{}'.format(i))
    ax.axis('off')
    show_landmarks(**sample) #通常的话本函数需要传入两个参数image, landmarks,但使用此方法可以得到字典中所有键对应的值

    if i == 3:   # 每四张显示一个figure
        plt.show()
        break

######################################################################
# 三个transform类(预处理方法)的具体实现:
# 上面返回的图都是原始图像,大小不一,所以一般来说不会直接输入到卷积网络。
# 上面我们在实现自己的dataset类时,可以传入参数transform, 下面我们看一看如何实现transform,并传入到dataset

class Rescale(object):
    """Rescale the image in a sample to a given size.
    Args:
        output_size (tuple or int): Desired output size. If tuple, output is matched to output_size.
        If int, smaller of image edges is matched to output_size keeping aspect ratio the same.
    """
    def __init__(self, output_size):  # 传入的参数为图像输出大小
        assert isinstance(output_size, (int, tuple))  # 断言:如果不match,就抛出异常
        if isinstance(output_size, int):              # 如果为int,例如256,则返回(256,256)大小的图
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2              # 如果为tuple,例如(211,985),则返回(211,985)大小的图
            self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]  # image.shape (h,w,c)
        if isinstance(self.output_size, int):  # 当输出size为int时,将此值作为图像的最短边长,而长边则需根据比例进行缩放
            if h > w:                          # h>w,缩放h
                new_h, new_w = self.output_size * h/w, self.output_size
            else:                              # h<=w,缩放w
                new_h, new_w = self.output_size, self.output_size * w/h
        else:                                  # 当输出size为tuple时,直接将此tuple作为图像输出尺寸
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)

        img = transform.resize(image, (new_h, new_w))
        # h and w are swapped for landmarks because for images,
        # x and y axes are axis 1 and 0 respectively
        landmarks = landmarks * [new_w/w, new_h/h]

        return {'image': img, 'landmarks': landmarks}  # 注意__getitem__返回的是字典,所以这里也要返回字典

class RandomCrop(object):
    """Crop randomly the image in a sample.
       Args:
           output_size (tuple or int): Desired output size. If int, square crop is made.
    """
    def __init__(self, output_size): # 传入的参数为图像输出大小
        assert isinstance(output_size, (int, tuple))      # 断言:如果不match,就抛出异常
        if isinstance(output_size, int):                  # 如果为int,例如256,则返回(256,256)大小的图
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2                  # 如果为tuple,例如(211,985),则返回(211,985)大小的图
            self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2] # image.shape (h,w,c)
        new_h, new_w =self.output_size

        top = np.random.randint(0, h-new_h)
        left = np.random.randint(0, w-new_w)

        image = image[top:top+new_h, left:left+new_w]
        landmarks = landmarks-[left, top]

        return {'image': image,'landmarks':landmarks}

class ToTensor(object): #第三个类转numpy为tensor

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        # swap color axis because
        # numpy image: H x W x C
        # torch image: C X H X W
        image = image.transpose((2, 0, 1)) # 转换维度、按照torch格式来
        return {'image': torch.from_numpy(image), 'landmarks': torch.from_numpy(landmarks)}

######################################################################
# Apply each of the above transforms on sample.

scale = Rescale(256)    # 实例化第一个类,此时该对象可当做函数使用
crop = RandomCrop(128)  # 实例化第二个类,此时该对象可当做函数使用
composed = transforms.Compose([Rescale(256), RandomCrop(224)])

fig = plt.figure()
sample = face_dataset[65]
for i, tsfrm in enumerate([scale, crop, composed]):  # 试着分别使用这三个函数
    transformed_sample = tsfrm(sample)               # sample作为参数传入了函数里面,返回image、landmark字典

    ax = plt.subplot(1, 3, i+1)
    plt.tight_layout()
    ax.set_title(type(tsfrm).__name__)
    show_landmarks(**transformed_sample)

plt.show()

######################################################################
# Iterating through the dataset
# 根据上文实现的transform,现在我们可以将其放到我们定制的dataset类里面。
# 每当我们的dataset被采样时便会读取一张图片、接着进行transform:
transformed_dataset = FaceLandmarksDataset(
    csv_file='./data/faces/face_landmarks.csv',
    root_dir='./data/faces/',
    transform=transforms.Compose([Rescale(256), RandomCrop(224), ToTensor()])
) # 实例化定制我们自定义的dataset

for i in range(len(transformed_dataset)):
    sample = transformed_dataset[i]      #每次采样一张图片,其索引为i

    print(i, sample['image'].size(), sample['landmarks'].size())
    if i == 3:
        break

######################################################################
# 在自定义数据集迭代时:上面依靠for循环,每次才能索引一张图,效率低下
# 那么我们需要batch批量数据读入、shuffle打散数据、multiprocessing并行处理时,怎么办?

dataloader = DataLoader(dataset=transformed_dataset, batch_size=4, shuffle=True, num_workers=4)

# Helper function to show a batch
def show_landmarks_batch(sample_batched): # 传进来参数为一个sample对象,自带image和landmarks字典形式
    """Show image with landmarks for a batch of samples."""
    images_batch, landmarks_batch = sample_batched['image'],  sample_batched['landmarks']
    batch_size = len(images_batch)
    im_size = images_batch.size(2)  # why 2 ?

    grid = utils.make_grid(images_batch)
    plt.imshow(grid.numpy().transpose((1, 2, 0)))

    for i in range(batch_size):
        plt.scatter(
            landmarks_batch[i, :, 0].numpy() + i*im_size,
            landmarks_batch[i, :, 1].numpy(),
            s = 10, marker='.', c = 'r'
        )
######################################################################
#
for i_batch, sample_batched in enumerate(dataloader): # i_batch 可以看做step
    print(i_batch, sample_batched['image'].size(), sample_batched['landmarks'].size())

    if i_batch == 3: #每个batchdo
        plt.figure()
        show_landmarks_batch(sample_batched)
        plt.axis('off')
        plt.ioff()
        plt.show()
        break

 

  • 6
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值