文章目录
- PyTorch 官网教材之 数据加载和处理教程
-
- 0. 官网链接
- 1. 加载 相关包
- 2. 使用路径 直接读取数据,找到某一张图片的名称及其 landmarks
- 3. 显示一张人脸图片及其 landmarks(直接读取图片和标签 来实现)
- 4. 使用 torch.utils.data.Dataset 构造数据集(Dataset class)
- 5. 显示多张图片及其 landmarks(通过访问 torch 的torch.utils.data.Dataset 构造的数据集实现 )
- 6. Transforms 转变(图像变,同时其 landmarks 也要改变)
- 7. Compose transforms 组合转换
- 8. Iterating through the dataset 通过数据集进行迭代
- 9. 总结 Afterword: torchvision
PyTorch 官网教材之 数据加载和处理教程
0. 官网链接
1. DATA LOADING AND PROCESSING TUTORIAL
1. 加载 相关包
from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
# Ignore warnings
import warnings
warnings.filterwarnings("ignore")
plt.ion() # interactive mode
python
2. 使用路径 直接读取数据,找到某一张图片的名称及其 landmarks
landmarks_frame = pd.read_csv('data/faces/face_landmarks.csv') # 所有图片的标签值
n = 65 # 第65个数据(图片)
img_name = landmarks_frame.iloc[n, 0] # 每 n 行的第0个数据是 图片的名称
landmarks = landmarks_frame.iloc[n, 1:].as_matrix() # 68个landmark points(描述一个point 需要 2个坐标,x坐标和y坐标)
landmarks = landmarks.astype('float').reshape(-1, 2) # 将 landmarks reshape 成坐标对的形式
print('Image name: {}'.format(img_name)) # Image name: person-7.jpg
print('Landmarks shape: {}'.format(landmarks.shape)) # Landmarks shape: (68, 2)
print('First 4 Landmarks: {}'.format(landmarks[:4])) # First 4 Landmarks: [[32. 65.] [33. 76.] [34. 86.] [34. 97.]]
3. 显示一张人脸图片及其 landmarks(直接读取图片和标签 来实现)
def show_landmarks(image, landmarks):
"""Show image with landmarks"""
plt.imshow(image) # 图片
plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r') # 68个点landmarks
plt.pause(0.001) # pause a bit so that plots are updated
plt.figure()
show_landmarks(io.imread(os.path.join('data/faces/', img_name)), landmarks)
plt.show()
4. 使用 torch.utils.data.Dataset 构造数据集(Dataset class)
class FaceLandmarksDataset(Dataset):
"""Face Landmarks dataset."""
def __init__(self, csv_file, root_dir, transform=None):
"""
Args:
csv_file (string): Path to the csv file with annotations.
root_dir (string): Directory with all the images.
transform (callable, optional): Optional transform to be applied
on a sample.
"""
self.landmarks_frame = pd.read_csv(csv_file)
self.root_dir = root_dir
self.transform = transform
def __len__(self):