pytorch入门(二):数据加载和处理

本章对应pytorch官方文档链接

小引

本篇主要介绍了如何利用 pytorch加载和处理数据集,并以图像数据集为例讲解了几种图像预处理的方法。

数据加载

引包

from __future__ import print_function, division
import os
import torch
import pandas as pd  #更加方便的处理csv
from skimage import io, transform  #用于图像io和变换
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

# Ignore warnings
#import warnings
#warnings.filterwarnings("ignore")

plt.ion()   # interactive mode

数据集

数据集
数据集共包含69张图片,每个样本包含图片名信息和68个界标点坐标。
数据集格式如下:

数据集格式:
image_name,part_0_x,part_0_y,part_1_x,part_1_y,part_2_x, ... ,part_67_x,part_67_y
0805personali01.jpg,27,83,27,98, ... 84,134
1084239450_e76e00b7e7.jpg,70,236,71,257, ... ,128,312

编写辅助函数

#读取数据集
landmarks_frame = pd.read_csv('data/faces/face_landmarks.csv')   #根据需要修改

n = 3
img_name = landmarks_frame.iloc[n, 0]
#对数据进行数组转化 iloc:通过行号来取行数据
landmarks = landmarks_frame.iloc[n, 1:].as_matrix()  #每张数据集具有136个坐标点
#landmarks = landmarks_frame.iloc[n, 1:].values # 二者等价
landmarks = landmarks.astype('float').reshape(-1, 2)  #转化为68行2列

print('Image name: {}'.format(img_name))
print('Landmarks shape: {}'.format(landmarks.shape))
print('First 4 Landmarks: {}'.format(landmarks[:4]))

注意:在使用 landmarks_frame.iloc[n, 1:].as_matrix() 方法时编译器可能会发出警告,显示方法即将过时,推荐使用landmarks_frame.iloc[n, 1:].values方法(values后没有括号!

输出如下信息:

Image name: person-7.jpg
Landmarks shape: (68, 2)
First 4 Landmarks: [[32. 65.]
 [33. 76.]
 [34. 86.]
 [34. 97.]]

显示图像及其特征点

#显示图像及特征点
#plt.imshow()函数负责对图像进行处理,并显示其格式,但是不能显示。其后跟着plt.show()才能显示出来。
def show_landmarks(image, landmarks):
    plt.imshow(image)
    plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker= '.',
                c = 'r')
    plt.pause(0.001)  #停顿使图片更新出来

plt.figure()
show_landmarks(io.imread(os.path.join('data/faces/', img_name)), landmarks)
plt.show()

输出结果:
在这里插入图片描述

定义数据集类

包含自定义函数:

__ len __,使得len(dataset)返回数据集的大小。

__ getitem __,支持索引,使得dataset[i]可以用来获取第i个样本。

以上两种方法的目的是为了提高内存效率,图形不是立即存储在内存中,而是按照需要读取。数据集中的样本是字典:{‘image’: image, ‘landmarks’: landmarks}

#自定义数据集类,继承数据集的抽象类
class FaceLandmarksDataset(Dataset):
    def __init__(self, csv_file, root_dir, transform = None):
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform  #图形的预处理,默认参数为None
        
    def __len__(self):
        return len(self.landmarks_frame)
    
    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        
        img_name = os.path.join(self.root_dir, self.landmarks_frame.iloc[idx, 0])
        image = io.imread(img_name)
        landmarks = self.landmarks_frame.iloc[idx, 1:].as_matrix()
        landmarks = landmarks.astype(dtype = "float").reshape(-1, 2)
        sample = {"image":image, "landmarks":landmarks}  #返回值样本为字典类型
        
        if self.transform:
            sample = self.transform(sample)  # 可以实现裁剪缩放等数据转换(transform类是有__call__方法的)
        
        return sample

实例化类并遍历,打印前四个样本:

face_dataset = FaceLandmarksDataset(csv_file="data/faces/face_landmarks.csv", root_dir="data/faces/")

fig = plt.figure()

for i in range(len(face_dataset)):
    sample = face_dataset[i]  #__getitem__ 获取第i个样本,返回字典
    
    print(i, sample['image'].shape, sample['landmarks'].shape)
    
    ax = plt.subplot(1, 4, i+1)
    plt.tight_layout()
    ax.set_title('Sample #{}'.format(i))
    ax.axis("off")
    show_landmarks(**sample) # 因为sample为字典,所以可以利用这种形式返回字典中所有键对应的值
    if i == 3:              
        plt.show()
        break

结果如下:
在这里插入图片描述

0 (324, 215, 3) (68, 2)
1 (500, 333, 3) (68, 2)
2 (250, 258, 3) (68, 2)
3 (434, 290, 3) (68, 2)

数据处理

isinstance()的用法
isinstance() 函数来判断一个对象是否是一个已知的类型,类似 type()。
isinstance() 与 type() 区别:
type() 不会认为子类是一种父类类型,不考虑继承关系。
isinstance() 会认为子类是一种父类类型,考虑继承关系。

enumerate() 函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标,一般用在 for 循环当中。

定义三种转换方式:Rescale图像缩放,RandomCrop随机裁剪和Totensor将numpy转tensor

#三种转换方式:返回样本Sample
class Rescale(object):     # 第一个类规范图像尺寸
    """Rescale the image in a sample to a given size.

    Args:
        output_size (tuple or int): Desired output size. If tuple, output is
            matched to output_size. If int, smaller of image edges is matched
            to output_size keeping aspect ratio the same.
    """

    def __init__(self, output_size):     # 此类需传入的参数为图像输出大小
        assert isinstance(output_size, (int, tuple))   # 这个size可以为int例如256,也可以为tuple,例如(256,256)
        self.output_size = output_size

    def __call__(self, sample):    #magic method:使类像函数一样可以调用
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        if isinstance(self.output_size, int):     # 当输出size为int时,将此值作为图像的最短边长,而长边则需根据比例进行缩放
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h 
        else:                                     # 当输出为tuple时,直接将此tuple作为图像输出尺寸
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)

        img = transform.resize(image, (new_h, new_w))

        # h and w are swapped for landmarks because for images,
        # x and y axes are axis 1 and 0 respectively
        landmarks = landmarks * [new_w / w, new_h / h]  

        return {'image': img, 'landmarks': landmarks}    # 注意__getitem__返回的是字典,所以这里也要返回字典


class RandomCrop(object):      #第二个类随机裁剪
    """Crop randomly the image in a sample.

    Args:
        output_size (tuple or int): Desired output size. If int, square crop
            is made.
    """

    def __init__(self, output_size):                    # 此类需传入输出尺寸
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):                # 如果为int例如256则返回任意(256,256)大小的图
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2                # 如果为tuple例如(211,985),则返回(211,985)大小的图
            self.output_size = output_size    

    def __call__(self, sample):  
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        new_h, new_w = self.output_size

        bottom = np.random.randint(0, h - new_h)  #裁剪边框底边
        left = np.random.randint(0, w - new_w)  #裁剪边框左边

        image = image[bottom: bottom + new_h,
                      left: left + new_w]

        landmarks = landmarks - [left, bottom] 

        return {'image': image, 'landmarks': landmarks}


class ToTensor(object):         # 第三个类转numpy为tensor
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):      # 无需init方法,直接将此类作为函数
        image, landmarks = sample['image'], sample['landmarks']

        # swap color axis because
        # numpy image: H x W x C
        # torch image: C X H X W
        image = image.transpose((2, 0, 1))     # 转换维度,按照torch格式来
        return {'image': torch.from_numpy(image),
                'landmarks': torch.from_numpy(landmarks)}

组合变换

利用torchvision.transforms.Compose,执行transform的组合变换。

scale = Rescale(256)      # 实例化第一个类,此时该对象可当做函数使用
crop = RandomCrop(128)    # 实例化第二个类,此时该对象可当做函数使用
composed = transforms.Compose([Rescale(256),    # 结合两个方法
                               RandomCrop(224)])

# Apply each of the above transforms on sample.
fig = plt.figure()
sample = face_dataset[1]
for i, tsfrm in enumerate([scale, crop, composed]):    # 试着分别使用这三个函数
    transformed_sample = tsfrm(sample)                 # sample作为参数传入了函数里面,返回image、landmark字典

    ax = plt.subplot(1, 3, i + 1)
    plt.tight_layout()
    ax.set_title(type(tsfrm).__name__)
    show_landmarks(**transformed_sample)               # 调用之前的函数进行显示

plt.show()

结果如下:
在这里插入图片描述

遍历数据集

方法一:for循环
创建具有组合变换的数据集,从文件中读取图像、变换,循环遍历。

# 实例化我们定制的dataset!已经完成相应的预处理
transformed_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',
                                           root_dir='data/faces/',
                                           transform=transforms.Compose([
                                               Rescale(256),
                                               RandomCrop(224),
                                               ToTensor()
                                           ]))                         

for i in range(len(transformed_dataset)):
    sample = transformed_dataset[i]                       # for循环, 每次采样索引为i的一张图片

    print(i, sample['image'].size(), sample['landmarks'].size())

    if i == 3:     # 查看4张图就好
        break

缺点:
1.没有批量处理数据
2.没有整理数据
3.没有使用多进程multiprocessing并行加载数据。

dataloader = DataLoader(transformed_dataset, batch_size=4,        # batch为4张,打散,进程数为4
                        shuffle=True, num_workers=4)     
print(len(dataloader))

# Helper function to show a batch
def show_landmarks_batch(sample_batched):       # 显示一个batch数据的函数,主要利用工具函数make_grid
    """Show image with landmarks for a batch of samples."""
    images_batch, landmarks_batch = \
            sample_batched['image'], sample_batched['landmarks']
    batch_size = len(images_batch)
    im_size = images_batch.size(2)

    grid = utils.make_grid(images_batch)                # 其输入为FLoatTensor
    plt.imshow(grid.numpy().transpose((1, 2, 0)))       # 只有当画图的时候才转为numpy并转换维度
  
    for i in range(batch_size):    #landmarks[69, 68, 2] 69张图片,每张图片68个点,每个点x\y坐标。
        plt.scatter(landmarks_batch[i, :, 0].numpy() + i * im_size,
                    landmarks_batch[i, :, 1].numpy(),
                    s=10, marker='.', c='r')

        plt.title('Batch from dataloader')
print("ok")
for i_batch, sample_batched in enumerate(dataloader):
    print(i_batch, sample_batched['image'].size(),
          sample_batched['landmarks'].size())

    # observe 4th batch and stop.
    if i_batch == 3:                         # 只打印第4个batch
        plt.figure()
        show_landmarks_batch(sample_batched)
        plt.axis('off')
        plt.ioff()
        plt.show()
        break

结果:
在这里插入图片描述

0 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
1 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
2 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
3 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])

其他注意事项

注意:官方文档代码在win10环境下运行可能会报错:BrokenPipeError: [Errno 32] Broken pipe

解决方法:不要进行多进程处理,num_workers=0,进程数改为0。
参考链接

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值