在使用tensorflow加载mnist数据集时报错,因为下载数据集的链接被墙了,没法下载数据集。
解决方法:
step1.手动下载数据集到本地。
mnist数据集的格式有两种,一种是gz格式,包含4个文件,一种是npz格式。
.gz格式数据集下载:http://yann.lecun.com/exdb/mnist/
mnist.npz格式数据集下载:https://pan.baidu.com/s/1kbDiH-nnbgmTRdaZM6c80g 提取码:sg2k
step2.加载本地数据集
# 方法1:mnist.npz格式的数据集的加载
def load_data_npz(path='mnist.npz'):
"""
path:mnist.npz文件的路径
"""
f = np.load(path) # np.load文件可以加载npz,npy格式的文件
x_train,y_train,x_test,y_test = f['x_train'], f['y_train'], f['x_test'],f['y_test']
f.close()
return x_train, y_train, x_test, y_test
# 调用load_data函数加载mnist.npz数据集
path = './MNIST_DATA/mnist.npz'
x_train_npz, y_train_npz, x_test_npz, y_test_npz = load_data_npz(path)
# 查看数据集的形状
print('x_train_npz:{}'.format(x_train_