前言
众所周知,Dataset和Dataloder是pytorch中进行数据载入的部件。必须将数据载入后,再进行深度学习模型的训练。在pytorch的一些案例教学中,常使用torchvision.datasets
自带的MNIST、CIFAR-10数据集,一般流程为:
# 下载并存放数据集
train_dataset = torchvision.datasets.CIFAR10(root="数据集存放位置",download=True)
# load数据
train_loader = torch.utils.data.DataLoader(dataset=train_dataset)
但是,在我们自己的模型训练中,需要使用非官方自制的数据集。这时应该怎么办呢?
我们可以通过改写torch.utils.data.Dataset
中的__getitem__
和__len__
来载入我们自己的数据集。
__getitem__
获取数据集中的数据,__len__
获取整个数据集的长度(即个数)。
改写
采用pytorch官网案例中提供的一个脸部landmark数据集。数据集中含有存放landmark的csv文件,但是我们在这篇文章中不使用(其实也可以随便下载一些图片作数据集来实验)。
import os
import torch
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms,