在线下载数据集太慢了,我们可以先手动下载数据文件,再从本地导入数据集。下面以 FashionMNIST 为例:
torch.utils.data.Dataset 是一个抽象类,用户想要加载自定义的数据集,只需继承这个类并实现以下三个方法:
1.__init__:数据集的初始化,加载等;
2.__getitem__:根据给定索引,返回数据集中对应位置的样本;
3.__len__:实现len(dataset),返回整个数据集的大小。
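注意:__init__ 可以按需编写;但如果不覆写 __getitem__ 和 __len__,使用时就会报错。早期版本 PyTorch 中基类的默认实现如下(新版本的基类不再提供 __len__,但未实现时调用 len(dataset) 同样会报错):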
def __getitem__(self, index):
    """Base-class placeholder shown for reference: always raises.

    Subclasses must override it to return the sample stored at *index*.
    """
    raise NotImplementedError()
def __len__(self):
    """Base-class placeholder shown for reference: always raises.

    Subclasses must override it to return the number of samples.
    """
    raise NotImplementedError()
自定义的数据集类如下:
import gzip
import os
import struct

import numpy as np
from torch.utils.data import DataLoader, Dataset
class DealDataset(Dataset):
    """Map-style dataset over a gzip-compressed image/label file pair.

    Both arrays are loaded eagerly (via ``load_data``) when the object is
    constructed and kept in memory. ``transform`` receives the image exactly
    as stored in ``train_set`` (a NumPy array when produced by
    ``load_data``), not a PIL image.
    """

    def __init__(self, folder, data_name, label_name, transform=None):
        """Load images and labels from *folder*.

        Args:
            folder: directory that contains both files.
            data_name: file name of the image file.
            label_name: file name of the label file.
            transform: optional callable applied to each image in
                ``__getitem__``; ``None`` returns the stored image as is.
        """
        images, labels = load_data(folder, data_name, label_name)
        self.train_set = images
        self.train_labels = labels
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(image, label)`` at *index*; the label is a Python ``int``."""
        sample = self.train_set[index]
        label = int(self.train_labels[index])
        if self.transform is None:
            return sample, label
        return self.transform(sample), label

    def __len__(self):
        """Return the number of stored images."""
        return len(self.train_set)
def _read_idx_gz(path, expected_magic, n_dims):
    """Read one gzip-compressed IDX file of unsigned bytes into a writable array.

    The header is a big-endian uint32 magic number followed by ``n_dims``
    big-endian uint32 dimension sizes; the payload is one byte per element.
    """
    with gzip.open(path, 'rb') as f:
        raw = f.read()
    header_size = 4 * (1 + n_dims)
    if len(raw) < header_size:
        raise ValueError(f'{path}: file too short for an IDX header')
    magic, *dims = struct.unpack_from(f'>{1 + n_dims}I', raw)
    if magic != expected_magic:
        raise ValueError(
            f'{path}: expected IDX magic number {expected_magic}, got {magic} '
            '(are data_name and label_name swapped?)')
    values = np.frombuffer(raw, np.uint8, offset=header_size)
    if values.size != int(np.prod(dims)):
        raise ValueError(
            f'{path}: header announces shape {tuple(dims)} but the file holds '
            f'{values.size} data bytes')
    # np.frombuffer over immutable bytes yields a read-only view; copying makes
    # the array writable so torch.from_numpy / ToTensor / default collation do
    # not warn about (or misbehave on) non-writable arrays.
    return values.reshape(dims).copy()


def load_data(data_folder, data_name, label_name):
    """Load a gzip-compressed IDX image/label file pair into NumPy arrays.

    The files use the big-endian IDX layout distributed with (Fashion-)MNIST:
    the label file has an 8-byte header (magic ``2049``, item count) and the
    image file a 16-byte header (magic ``2051``, item count, rows, cols),
    each followed by one unsigned byte per label / pixel. The image size is
    taken from the header instead of being hard-coded.

    Args:
        data_folder: directory that contains both files.
        data_name: file name of the images, e.g. ``train-images-idx3-ubyte.gz``.
        label_name: file name of the labels, e.g. ``train-labels-idx1-ubyte.gz``.

    Returns:
        ``(x_train, y_train)``: a writable ``uint8`` array of shape
        ``(n, rows, cols)`` and a writable ``uint8`` array of shape ``(n,)``.

    Raises:
        ValueError: if a header is truncated or has the wrong magic number
            (e.g. the two file names were swapped), if a payload size
            disagrees with its header, or if the files hold different
            numbers of samples.
    """
    label_path = os.path.join(data_folder, label_name)
    image_path = os.path.join(data_folder, data_name)
    y_train = _read_idx_gz(label_path, 2049, 1)   # 0x0801: ubyte, 1 dim
    x_train = _read_idx_gz(image_path, 2051, 3)   # 0x0803: ubyte, 3 dims
    if len(x_train) != len(y_train):
        raise ValueError(
            f'{len(x_train)} images in {image_path} but {len(y_train)} '
            f'labels in {label_path}')
    return (x_train, y_train)
设计好数据集类后,就可以用 torch.utils.data.DataLoader 来批量加载并访问数据了。
# Folder holding the manually downloaded .gz files. A raw string cannot end
# in a single backslash, so the trailing separator is simply omitted;
# load_data joins folder and file name with os.path.join.
trainDataset = DealDataset(r'D:\MachineLearning\data\FashionMNIST\raw',
                           'train-images-idx3-ubyte.gz',
                           'train-labels-idx1-ubyte.gz')

# Batch and shuffle the samples; iterate with
# ``for images, labels in train_loader: ...``.
# num_workers is left at its default of 0: on Windows, worker processes would
# additionally require an ``if __name__ == '__main__':`` guard around this code.
train_loader = DataLoader(trainDataset, batch_size=64, shuffle=True)
Over