PyTorch官方教程 - Getting Started - 数据加载和处理

DATA LOADING AND PROCESSING TUTORIAL

  • scikit-image
  • pandas
import os
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

%matplotlib inline
# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

plt.ion()   # interactive mode

# Load the annotation table: column 0 holds the image file name and the
# remaining columns hold the landmark coordinates (part_k_x, part_k_y, ...).
landmarks_frame = pd.read_csv("../../datasets/faces/face_landmarks.csv")
print(landmarks_frame.sample(10))

# Pick one row to inspect.
n = 65
img_name = landmarks_frame.iloc[n, 0]

# ``Series.as_matrix()`` and the ``np.float`` alias no longer exist in
# current pandas / NumPy releases; ``np.asarray(..., dtype=float)`` performs
# the same conversion to a float64 array on old and new versions alike.
# The flat coordinate row is then folded into one (x, y) pair per landmark.
landmarks = np.asarray(landmarks_frame.iloc[n, 1:], dtype=float)
landmarks = landmarks.reshape(-1, 2)

print("Image name: {}".format(img_name))
print("Landmarks shape: {}".format(landmarks.shape))
print("First 4 Landmarks: {}".format(landmarks[: 4]))

                   image_name  part_0_x  part_0_y  part_1_x  part_1_y  \
24  2722779845_7fcb64a096.jpg        71       128        63       151   
14  2322901504_08122b01ba.jpg       195       202       191       219   
20   262007783_943bbcf613.jpg       123       137       128       176   
3    110276240_bec305da91.jpg        42       140        45       161   
30   299733036_fff5ea6f8e.jpg       122       149       121       163   
18                2382SJ8.jpg        50       100        49       112   
51  3718903026_c1bf5dfcf8.jpg       125       220       124       245   
35  3273658251_b95f65c244.jpg       220        85       221       110   
62           britney-bald.jpg        52       134        54       149   
41   348272697_832ce65324.jpg       127       222       127       246   

    part_2_x  part_2_y  part_3_x  part_3_y  part_4_x  ...  part_63_x  \
24        58       173        56       196        57  ...        135   
14       187       236       191       254       198  ...        234   
20       137       215       150       250       171  ...        311   
3         51       180        61       200        73  ...        144   
30       119       177       119       192       124  ...        186   
18        50       124        52       137        56  ...        104   
51       123       269       125       295       134  ...        232   
35       224       135       229       159       237  ...        326   
62        56       164        60       179        65  ...        110   
41       128       269       130       292       136  ...        241   

    part_63_y  part_64_x  part_64_y  part_65_x  part_65_y  part_66_x  \
24        226        156        240        133        228        124   
14        269        273        272        235        279        226   
20        256        316        252        312        254        302   
3         197        180        189        147        204        136   
30        207        188        211        186        207        182   
18        137        118        144        104        147         96   
51        300        263        300        232        323        215   
35        175        343        177        326        174        316   
62        193        126        193        110        199        104   
41        286        258        294        240        287        231   

    part_66_y  part_67_x  part_67_y  
24        225        116        221  
14        277        219        274  
20        258        292        258  
3         208        125        209  
30        206        178        204  
18        147         89        146  
51        324        199        320  
35        175        306        174  
62        200         98        198  
41        287        222        285  

[10 rows x 137 columns]
Image name: person-7.jpg
Landmarks shape: (68, 2)
First 4 Landmarks: [[32. 65.]
 [33. 76.]
 [34. 86.]
 [34. 97.]]
def show_landmarks(image, landmarks):
    """Draw ``image`` and overlay ``landmarks`` on it as small red dots.

    Args:
        image: Image accepted by ``plt.imshow``.
        landmarks: 2-D array indexed as ``[:, 0]`` for x and ``[:, 1]``
            for y coordinates.
    """
    plt.imshow(image)
    xs = landmarks[:, 0]
    ys = landmarks[:, 1]
    plt.scatter(xs, ys, s=10, marker=".", c="r")
    # Pause briefly so that the figure gets a chance to refresh.
    plt.pause(0.001)
    
# Open a new figure, read the image named by ``img_name`` from the faces
# directory and draw it together with the ``landmarks`` array selected above.
plt.figure()
show_landmarks(io.imread(os.path.join("../../datasets/faces/", img_name)), landmarks)
plt.show()

在这里插入图片描述

Dataset类

torch.utils.data.Dataset为表示数据集的抽象类。Dataset类的子类需要重载:

  • __len__:len(dataset)返回数据集尺寸
  • __getitem__:支持索引操作dataset[i]
class FaceLandmarksDataset(Dataset):
    """Face landmarks dataset backed by a CSV annotation file.

    Each CSV row describes one image: column 0 is the image file name
    (relative to ``root_dir``) and the remaining columns are the flattened
    landmark coordinates, which are reshaped to one ``(x, y)`` pair per row.
    """

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        """Return the number of annotated images (rows of the CSV)."""
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        """Load the image and landmarks stored at row ``idx``.

        Args:
            idx (int or torch.Tensor): Row index of the requested sample.

        Returns:
            dict: ``{"image": ..., "landmarks": ...}`` where ``landmarks`` is
            a float array of shape ``(-1, 2)``; if a transform was given, the
            result of applying it to that dict.
        """
        # Accept tensor indices as well by turning them into plain Python
        # values before they reach ``iloc``.
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.root_dir,
                                self.landmarks_frame.iloc[idx, 0])
        image = io.imread(img_name)

        # ``Series.as_matrix()`` and ``np.float`` were removed from pandas
        # and NumPy; ``np.asarray(..., dtype=float)`` yields the same
        # float64 array on both old and current releases.
        landmarks = np.asarray(self.landmarks_frame.iloc[idx, 1:], dtype=float)
        landmarks = landmarks.reshape(-1, 2)
        sample = {"image": image, "landmarks": landmarks}

        # Compare against None explicitly so that a transform object which
        # happens to evaluate as falsy is still applied.
        if self.transform is not None:
            sample = self.transform(sample)

        return sample
    

# Build the dataset without any transform, so samples keep their raw sizes.
face_dataset = FaceLandmarksDataset(csv_file="../../datasets/faces/face_landmarks.csv",
                                    root_dir="../../datasets/faces/")

fig = plt.figure()

# Print the shapes of the first four samples and draw each of them in its
# own cell of a 1 x 4 subplot grid.
for i in range(len(face_dataset)):
    sample = face_dataset[i]
    
    print(i, sample["image"].shape, sample["landmarks"].shape)
    
    ax = plt.subplot(1, 4, i + 1)
    plt.tight_layout()
    ax.set_title("Sample #{}".format(i))
    ax.axis("off")
    # The sample dict keys match the parameter names of show_landmarks.
    show_landmarks(**sample)
    
    # Stop after the fourth sample (index 3) once the figure is shown.
    if i == 3:
        plt.show()
        break
0 (324, 215, 3) (68, 2)

在这里插入图片描述

1 (500, 333, 3) (68, 2)

在这里插入图片描述

2 (250, 258, 3) (68, 2)

在这里插入图片描述

3 (434, 290, 3) (68, 2)

在这里插入图片描述

Transforms

  • Rescale:缩放
  • RandomCrop:随机裁剪
  • ToTensor:numpy图像转torch图像

实现可调用类:__call__()

class Rescale(object):
    """Resize the image of a sample and scale its landmarks to match.

    Args:
        output_size (tuple or int): Target size. A tuple is used directly as
            ``(height, width)``. An int sets the length of the shorter image
            edge; the other edge is scaled by the same ratio.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        image = sample["image"]
        landmarks = sample["landmarks"]

        height, width = image.shape[: 2]
        target = self.output_size

        if not isinstance(target, int):
            # Explicit (height, width) pair.
            out_h, out_w = target
        elif height > width:
            # Portrait image: width is the shorter edge.
            out_h, out_w = target * height / width, target
        else:
            # Landscape or square image: height is the shorter edge.
            out_h, out_w = target, target * width / height

        out_h = int(out_h)
        out_w = int(out_w)

        resized = transform.resize(image=image, output_shape=(out_h, out_w))

        # Landmarks are stored as (x, y), i.e. (column, row), hence the
        # width ratio comes first and the height ratio second.
        rescaled_landmarks = landmarks * [out_w / width, out_h / height]

        return {"image": resized, "landmarks": rescaled_landmarks}
    
    
class RandomCrop(object):
    """Crop randomly the image in a sample.

    Args:
        output_size (tuple or int): Desired output size. If int, square crop
            is made.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        """Return a randomly positioned crop of ``sample``.

        Args:
            sample (dict): Dict with an ``"image"`` array whose first two
                axes are height and width, and a ``"landmarks"`` array of
                ``(x, y)`` rows.

        Returns:
            dict: The cropped image and the landmarks shifted into the
            coordinate frame of the crop.

        Raises:
            ValueError: If the requested crop is larger than the image.
        """
        image, landmarks = sample["image"], sample["landmarks"]

        h, w = image.shape[: 2]
        new_h, new_w = self.output_size

        if new_h > h or new_w > w:
            raise ValueError(
                "Crop size ({}, {}) is larger than image size ({}, {})".format(
                    new_h, new_w, h, w))

        # The upper bound of randint is exclusive. Without the "+ 1" the last
        # valid offset is never drawn, and a crop exactly as large as the
        # image fails because the range 0..0 is empty.
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)

        image = image[top : top + new_h, left : left + new_w]

        # Landmarks are (x, y) pairs, so subtract (left, top) in that order.
        landmarks = landmarks - [left, top]

        return {"image": image, "landmarks": landmarks}
    
    
class ToTensor(object):
    """Turn the NumPy arrays of a sample into torch tensors."""

    def __call__(self, sample):
        image = sample["image"]
        landmarks = sample["landmarks"]

        # Move the channel axis to the front:
        #   numpy image: H x W x C
        #   torch image: C x H x W
        channels_first = image.transpose(2, 0, 1)

        converted = {}
        converted["image"] = torch.from_numpy(channels_first)
        converted["landmarks"] = torch.from_numpy(landmarks)
        return converted
    

组合变换

  • torchvision.transforms.Compose
# Three callables to compare: a rescale, a random crop, and a pipeline that
# chains a rescale with a random crop.
scale = Rescale(256)
crop = RandomCrop(128)
composed = transforms.Compose([Rescale(256), RandomCrop(224)])

# Apply each of the above transforms on sample.
fig = plt.figure()
sample = face_dataset[65]
for i, tsfrm in enumerate([scale, crop, composed]):
    transformed_sample = tsfrm(sample)
    
    # One subplot per transform, titled with the transform's class name.
    ax = plt.subplot(1, 3, i + 1)
    plt.tight_layout()
    ax.set_title(type(tsfrm).__name__)
    show_landmarks(**transformed_sample)
    
plt.show()

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

遍历数据集

# Dataset whose samples pass through Rescale, RandomCrop and ToTensor, in
# that order, every time they are fetched.
transformed_dataset = FaceLandmarksDataset(csv_file="../../datasets/faces/face_landmarks.csv",
                                           root_dir="../../datasets/faces/",
                                           transform=transforms.Compose([
                                               Rescale(256),
                                               RandomCrop(224),
                                               ToTensor()
                                           ]))

# Print the tensor sizes of the first four samples.
for i in range(len(transformed_dataset)):
    sample = transformed_dataset[i]
    
    print(i, sample["image"].size(), sample["landmarks"].size())
    
    # Stop after index 3.
    if i == 3:
        break
0 torch.Size([3, 224, 224]) torch.Size([68, 2])
1 torch.Size([3, 224, 224]) torch.Size([68, 2])
2 torch.Size([3, 224, 224]) torch.Size([68, 2])
3 torch.Size([3, 224, 224]) torch.Size([68, 2])
  • torch.utils.data.DataLoader

    • Batching the data
    • Shuffling the data
    • Load the data in parallel using multiprocessing workers
# Wrap the transformed dataset in a DataLoader that yields shuffled batches
# of four samples; num_workers=0 is passed, i.e. no worker processes.
dataloader = DataLoader(dataset=transformed_dataset,
                        batch_size=4,
                        shuffle=True,
                        num_workers=0)

# Plotting helper: renders one batch as a single image grid.
def show_landmarks_batch(sample_batched):
    """Plot the images of a batch side by side with their landmarks.

    Args:
        sample_batched: Mapping with an ``"image"`` batch tensor and a
            ``"landmarks"`` batch tensor, indexed by sample along axis 0.
    """
    images_batch = sample_batched["image"]
    landmarks_batch = sample_batched["landmarks"]

    batch_size = len(images_batch)
    im_size = images_batch.size(2)
    # Border, in pixels, assumed around every grid tile when shifting the
    # landmarks (presumably the default padding of make_grid - verify).
    grid_border_size = 2

    grid = utils.make_grid(images_batch)
    # The grid is channels-first; imshow wants channels last.
    grid_image = grid.numpy().transpose((1, 2, 0))
    plt.imshow(grid_image)

    for i in range(batch_size):
        xs = landmarks_batch[i, :, 0].numpy()
        ys = landmarks_batch[i, :, 1].numpy()
        # Shift each sample's landmarks onto its own tile of the grid.
        plt.scatter(xs + i * im_size + (i + 1) * grid_border_size,
                    ys + grid_border_size,
                    s=10, marker=".", c="r")

    plt.title("Batch from dataloader")
    
# Iterate over the batches, printing the size of the image and landmark
# tensors of each one.
for i_batch, sample_batched in enumerate(dataloader):
    print(i_batch, sample_batched["image"].size(),
          sample_batched["landmarks"].size())
    
    # observe 4th batch and stop.
    if i_batch == 3:
        plt.figure()
        show_landmarks_batch(sample_batched)
        plt.axis("off")
        # Turn interactive mode off before the final show() call.
        plt.ioff()
        plt.show()
        break
0 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
1 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
2 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
3 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])

在这里插入图片描述

Afterword: torchvision

  • ImageFolder
root/ants/xxx.png
root/ants/xxy.jpeg
root/ants/xxz.png
.
.
.
root/bees/123.jpg
root/bees/nsdf3.png
root/bees/asd932_.png
import torch
from torchvision import transforms, datasets

# Preprocessing pipeline for the ImageFolder example: random resized crop to
# 224, random horizontal flip, tensor conversion and per-channel
# normalization with the given mean / std.
# ``RandomSizedCrop`` was deprecated and later removed from torchvision;
# ``RandomResizedCrop`` is its replacement.
data_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# ImageFolder reads images from the sub-directories of ``root`` (the
# root/ants/..., root/bees/... layout) and applies ``data_transform`` to
# every image.
hymenoptera_dataset = datasets.ImageFolder(root="../../datasets/hymenoptera_data/train",
                                           transform=data_transform)
# Shuffled batches of four images, loaded by four worker processes.
dataset_loader = torch.utils.data.DataLoader(dataset=hymenoptera_dataset,
                                             batch_size=4,
                                             shuffle=True,
                                             num_workers=4)

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值