Loading the fashion_mnist dataset offline
Casual notes from a beginner who is learning machine learning with Python.
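While working with the fashion_mnist dataset I ran into a problem: every file in the dataset is a .gz archive, and after downloading them to my machine I did not know how to load them in Jupyter Notebook.
(The fashion_mnist documentation on GitHub says mnist_reader can be used for this, but I could not get it to work.)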
After searching online for quite a while, I found one blogger's approach, and it worked when I tried it.
The code is as follows:
import gzip
import os

import numpy as np

def load_data(data_folder):
    """Load Fashion-MNIST from the four .gz files stored in data_folder."""
    files = [
        'train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz',
        't10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz',
    ]
    paths = [os.path.join(data_folder, fname) for fname in files]

    # Label files: skip the 8-byte header, then one uint8 label per sample.
    with gzip.open(paths[0], 'rb') as lbpath:
        y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)
    # Image files: skip the 16-byte header, then 28x28 uint8 pixels per sample.
    with gzip.open(paths[1], 'rb') as imgpath:
        X_train = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28)
    with gzip.open(paths[2], 'rb') as lbpath:
        y_test = np.frombuffer(lbpath.read(), np.uint8, offset=8)
    with gzip.open(paths[3], 'rb') as imgpath:
        X_test = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(y_test), 28, 28)
    return (X_train, y_train), (X_test, y_test)

# Forward slashes in the folder path work on Windows as well as Linux/macOS.
(train_images, train_labels), (test_images, test_labels) = load_data('./data/fashion/')
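
The offsets 8 and 16 used above are the sizes of the headers in the IDX format that these files use: 4 bytes of magic number, followed by one 4-byte big-endian integer per dimension (one dimension for labels, three for images). The sketch below is my own addition, not part of the blogger's code; the helper names parse_idx_header and read_idx_header are hypothetical. It reads the header so the sample count and image size can be checked before reshaping.

```python
import gzip
import struct


def parse_idx_header(raw):
    """Return (dtype_code, dims, data_offset) for the start of an IDX byte string."""
    # Bytes 0-1 are zero, byte 2 is the data type code (8 = unsigned byte),
    # byte 3 is the number of dimensions.
    zero, dtype_code, ndim = struct.unpack(">HBB", raw[:4])
    if zero != 0:
        raise ValueError("not an IDX file: the first two bytes must be zero")
    # Each dimension size is a 4-byte big-endian unsigned integer.
    dims = struct.unpack(">" + "I" * ndim, raw[4:4 + 4 * ndim])
    return dtype_code, dims, 4 + 4 * ndim


def read_idx_header(path):
    """Hypothetical helper: parse the header of a gzip-compressed IDX file."""
    with gzip.open(path, "rb") as f:
        # 16 bytes is enough for files with up to three dimensions.
        return parse_idx_header(f.read(16))
```

For example, calling read_idx_header on the training image file should give a data offset of 16 and on the training label file a data offset of 8, which matches the offset arguments passed to np.frombuffer in load_data.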