from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import cv2
# Suppress every warning globally for the remainder of this script.
import warnings

warnings.simplefilter("ignore")
# Dataset class
class FaceLandmarksDataset(Dataset):
    """Face Landmarks dataset.

    Each row of the annotations CSV holds an image file name (first column,
    resolved relative to ``root_dir``) followed by a flat list of landmark
    coordinates, which are returned as (N, 2) coordinate pairs.
    """

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        """Return the number of annotated images (rows in the CSV)."""
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        """Load the image and landmarks stored at row position ``idx``.

        Args:
            idx (int or 0-d torch.Tensor): Row position in the annotations.

        Returns:
            The sample ``{'image': image, 'landmarks': landmarks}``, where
            ``landmarks`` is a float array of shape (N, 2), or whatever
            ``transform`` returns for that sample when a transform was given.
        """
        if torch.is_tensor(idx):
            # .iloc does not accept tensors; turn a scalar tensor index into a
            # plain Python number (.item() raises for multi-element tensors).
            idx = idx.item()

        img_name = os.path.join(self.root_dir,
                                self.landmarks_frame.iloc[idx, 0])
        image = io.imread(img_name)

        # Every column after the file name is a coordinate; group them in pairs.
        landmarks = self.landmarks_frame.iloc[idx, 1:].to_numpy()
        landmarks = landmarks.astype('float').reshape(-1, 2)
        sample = {'image': image, 'landmarks': landmarks}

        # Compare against None so a transform object is never skipped merely
        # because it happens to evaluate as falsy.
        if self.transform is not None:
            sample = self.transform(sample)
        return sample
def show_landmarks(image, landmarks):
    """Display ``image`` with its landmark points overlaid as small red dots.

    Args:
        image: Image passed to ``plt.imshow``.
        landmarks: Array whose columns 0 and 1 are used as the x and y
            coordinates of the scattered points.
    """
    plt.imshow(image)
    xs, ys = landmarks[:, 0], landmarks[:, 1]
    plt.scatter(xs, ys, s=10, marker='.', c='r')
    # Brief pause so the figure gets redrawn when plotting interactively.
    plt.pause(0.001)
face_dataset = FaceLandmarksDataset(csv_file='../data/faces/face_landmarks.csv',
                                    root_dir='../data/faces/')

# Preview at most the first six samples in a 2 x 3 grid of subplots.
n_rows, n_cols = 2, 3
n_preview = min(len(face_dataset), n_rows * n_cols)

fig = plt.figure()
for i in range(n_preview):
    sample = face_dataset[i]  # invokes FaceLandmarksDataset.__getitem__(i)
    print(i, sample['image'].shape, sample['landmarks'].shape)

    ax = plt.subplot(n_rows, n_cols, i + 1)
    ax.set_title('Sample #{}'.format(i))
    ax.axis('off')  # hide the x/y axes
    # ** unpacks the sample dict into keyword arguments:
    # show_landmarks(image=..., landmarks=...)
    show_landmarks(**sample)

# Adjust subplot spacing once every panel (and its title) exists, then always
# show the figure, even when the dataset has fewer than six samples.
plt.tight_layout()
plt.show()