使用 PyTorch 训练自己的模型时,第一件事,也是最重要的一件事,就是将训练集按 PyTorch 要求的格式导入并完成预处理。完整教程代码如下,主要内容包括:导入人脸训练集、自定义数据集类、自定义图像转换(预处理)方法。主要参考英文教程:https://pytorch.org/tutorials/beginner/data_loading_tutorial.html
from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
# Ignore warnings
import warnings
warnings.filterwarnings('ignore')
# Turn on matplotlib's interactive mode before any figure is created.
plt.ion()

# Column 0 of the CSV is read as the image file name; the remaining columns
# of each row are read as that image's landmark values.
landmarks_frame = pd.read_csv('./data/faces/faces/face_landmarks.csv')

# Row index of the sample to inspect.
n = 65
img_name = landmarks_frame.iloc[n, 0]
# `as_matrix()` was deprecated in pandas 0.23 and removed in pandas 1.0;
# `.values` yields the same NumPy array on both old and new releases.
landmarks = landmarks_frame.iloc[n, 1:].values
# Convert to float and fold the flat row into rows of two values each
# (presumably (x, y) pairs -- confirm against the CSV header).
landmarks = landmarks.astype('float').reshape(-1, 2)

print('Image name:{}'.format(img_name))
print('Landmarks shape:{}'.format(landmarks.shape))
print('First 4 Landmarks:{}'.format(landmarks[:4]))
def show_landmarks(image, landmarks):
    """Show an image with its landmarks drawn on top.

    Args:
        image: Image data passed straight to ``plt.imshow``.
        landmarks: 2-D array; column 0 is plotted on the x-axis and
            column 1 on the y-axis, one red dot per row.
    """
    plt.imshow(image)
    plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r')
    # Pause briefly so the figure gets a chance to redraw.
    plt.pause(0.001)
# Open a new figure, load the selected sample image from disk, and draw it
# together with its landmarks.
plt.figure()
sample_path = os.path.join('data/faces/faces/', img_name)
sample_image = io.imread(sample_path)
show_landmarks(sample_image, landmarks)
plt.show(