Pytorch入门学习(六)--- 加载数据以及预处理(初步)--- 只为了更好理解流程

本文介绍了PyTorch中加载和预处理数据的基础知识,包括通过函数读取图片、自定义Dataset、图像变换、使用torchvision.transforms.Compose组合变换以及利用DataLoader进行批量、shuffle和并行加载数据。内容详细展示了数据处理的完整流程。
摘要由CSDN通过智能技术生成

直接从Pytorch Tutorials拿过来,看看。

需要的包:
1. scikit-image: 图像io以及变形
2. pandas: 读入csv文件
数据:
faces
csv的数据形式:
总共68个人脸关键点。

     image_name,part_0_x,part_0_y,part_1_x,part_1_y,part_2_x, ... ,part_67_x,part_67_y
     0805personali01.jpg,27,83,27,98, ... 84,134
     1084239450_e76e00b7e7.jpg,70,236,71,257, ... ,128,312

或是如图:
这里写图片描述

最简单的通过函数读取图片

# -*- coding: utf-8 -*-
"""
Data Loading and Processing Tutorial
====================================
**Author**: `Sasank Chilamkurthy <https://chsasank.github.io>`_
"""

from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

plt.ion()   # interactive mode


landmarks_frame = pd.read_csv('faces/face_landmarks.csv')

n = 65
img_name = landmarks_frame.ix[n, 0]
landmarks = landmarks_frame.ix[n, 1:].as_matrix().astype('float')
landmarks = landmarks.reshape(-1, 2)

print('Image name: {}'.format(img_name))
print('Landmarks shape: {}'.format(landmarks.shape))
print('First 4 Landmarks: {}'.format(landmarks[:4]))

# 定义show_landmarks.
def show_landmarks(image, landmarks):
    """Show image with landmarks"""
    plt.imshow(image)
    plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r')
    plt.pause(0.001)  # pause a bit so that plots are updated

plt.figure()
# 用 io.imread来读取图片
show_landmarks(io.imread(os.path.join('faces/', img_name)),
               landmarks)
plt.show()

通过继承Dataset

# 自定义数据集时,要继承 Dataset类。
# 一般至少要有 __init__, __len__, __getitem__
class FaceLandmarksDataset(Dataset):
    """Face Landmarks dataset."""

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.landmarks_frame.ix[idx, 0])
        image = io.imread(img_name)
        landmarks = self.landmarks_frame.ix[idx, 1:].as_matrix().astype('float')
        landmarks = landmarks.reshape(-1, 2)
        sample = {
  'image': image, 'landmarks': landmarks}
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值