导入自定义数据
来源官方教程
数据下载链接:https://download.pytorch.org/tutorial/faces.zip
1、导入库
import matplotlib.pyplot as plt
import os
import numpy as np
import pandas as pd
import warnings
import torchsnooper
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import torch
from skimage import io, transform
from torchvision import transforms, utils
warnings.filterwarnings("ignore")
2、查看单个样本
landmark_frame = pd.read_csv("faces\\face_landmarks.csv")
n = 64
image_name = landmark_frame.iloc[n, 0]
landmark = landmark_frame.iloc[n, 1:]
landmark = np.asarray(landmark)
landmark = landmark.astype(float).reshape(-1, 2)
def show_image(image_file, landmarks):
image = io.imread(image_file)
plt.imshow(image)
plt.scatter(landmark[:, 0], landmark[:, 1], s=10, c='r', marker='.')
plt.pause(0.001)
plt.ion()
image_file = os.path.join('faces', image_name)
image = io.imread(image_file)
plt.figure()
show_image(image_file, landmark)
plt.show()
知识点:
pandas:
DataFrame.iloc[]#像数组一样访问元素
3、制作数据集
总结:
继承类:torch.utils.data.Dataset 初始化:init() 实现接口:len(self)、getitem(self, idx)返回: 返回的类型类似于字典列表,可以通过方括号[]进行索引获得每条数据。 类似于data = [dict1,dict2,dict3],data[0]
此处是继承一个类,并且要实现其接口,接口必须要实现,通过接口使得这个类更加具有灵活 性,想返回什么样类型的,只要将其包装成字典就可以。
class FaceLandMarksDataset(Dataset):
def __init__(self, csv_file, root_dir, transform=None):
super(FaceLandMarksDataset, self).__init__()
self.img_file