MNIST数据集一般有两种加载方法:一种是直接使用torchvision中已经封装好的`torchvision.datasets.MNIST`,另一种是手动读取原始数据文件。这里讲解手动加载MNIST数据集的方法。
下载
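首先从MNIST官网(http://yann.lecun.com/exdb/mnist/)下载数据集,一共有四个压缩包:train-images-idx3-ubyte.gz、train-labels-idx1-ubyte.gz、t10k-images-idx3-ubyte.gz 和 t10k-labels-idx1-ubyte.gz。下载后无需解压——下面的代码通过gzip直接读取.gz文件。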
读取数据
将下面的代码复制到readdata.py中,然后传入数据集所在目录以及图像文件名、标签文件名即可读取。
import os
import gzip
import numpy as np
from torch.utils.data import Dataset
def load_data(data_folder, data_name, label_name):
    """Load one MNIST split from its gzip-compressed IDX files.

    The ``.gz`` files are read directly; they do not need to be extracted.

    Args:
        data_folder: directory containing the MNIST files.
        data_name: file name of the IDX image file
            (e.g. ``train-images-idx3-ubyte.gz``).
        label_name: file name of the IDX label file
            (e.g. ``train-labels-idx1-ubyte.gz``).

    Returns:
        ``(x_train, y_train)``: ``x_train`` is a ``uint8`` array of shape
        ``(N, rows, cols)`` (28x28 for MNIST) and ``y_train`` is a ``uint8``
        array of shape ``(N,)``. Both are views created with
        ``np.frombuffer`` over the decompressed bytes.

    Raises:
        ValueError: if a file does not start with the expected IDX magic
            number (2049 for labels, 2051 for images), is truncated, or the
            image count does not match the label count.
    """
    # 'rb': the IDX files are binary.
    with gzip.open(os.path.join(data_folder, label_name), 'rb') as lbpath:
        label_bytes = lbpath.read()
    with gzip.open(os.path.join(data_folder, data_name), 'rb') as imgpath:
        image_bytes = imgpath.read()

    # Label header: two big-endian uint32 values (magic, count) = 8 bytes.
    label_magic, n_labels = (
        int(v) for v in np.frombuffer(label_bytes, dtype='>u4', count=2))
    if label_magic != 2049:
        raise ValueError(
            f"{label_name!r} is not an IDX label file (magic {label_magic})")

    # Image header: four big-endian uint32 values (magic, count, rows, cols) = 16 bytes.
    image_magic, n_images, rows, cols = (
        int(v) for v in np.frombuffer(image_bytes, dtype='>u4', count=4))
    if image_magic != 2051:
        raise ValueError(
            f"{data_name!r} is not an IDX image file (magic {image_magic})")
    if n_images != n_labels:
        raise ValueError(
            f"image count ({n_images}) does not match label count ({n_labels})")

    # count= makes np.frombuffer raise ValueError on a truncated file instead
    # of silently returning fewer samples.
    y_train = np.frombuffer(label_bytes, np.uint8, count=n_labels, offset=8)
    x_train = np.frombuffer(
        image_bytes, np.uint8, count=n_images * rows * cols, offset=16,
    ).reshape(n_images, rows, cols)
    return (x_train, y_train)
class CustomDataset(Dataset):
    """Map-style PyTorch dataset over one MNIST split.

    Images and labels are loaded once, at construction, via ``load_data``;
    each item is then served as an ``(image, label)`` pair.

    Args:
        folder: directory containing the MNIST files.
        data_name: file name of the image file.
        label_name: file name of the label file.
        transform: optional callable applied to each image array before it
            is returned (for example a torchvision transform).
    """

    def __init__(self, folder, data_name, label_name, transform=None):
        # Note: torch.load() would only be an alternative if the data had
        # already been saved as tensors (e.g. a .pt file); it does not parse
        # the raw IDX .gz files read here.
        images, labels = load_data(folder, data_name, label_name)
        # The "train_*" attribute names are kept even when this instance
        # wraps the test split, so existing callers keep working.
        self.train_set = images
        self.train_labels = labels
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(image, label)`` for sample ``index``.

        ``image`` is a fresh copy of the stored array, passed through
        ``transform`` when one was given; ``label`` is a Python ``int``.
        """
        # np.array() copies, so the returned image is independent of the
        # stored data, which load_data builds with np.frombuffer.
        image = np.array(self.train_set[index])
        label = int(self.train_labels[index])
        transform = self.transform
        if transform is None:
            return image, label
        return transform(image), label

    def __len__(self):
        """Return the number of samples."""
        return len(self.train_set)