在用TensorFlow的mnist数据集做手写数字识别任务时,使用tensorflow自带的模块(如下所示)下载和导入数据集会报错,原因是该模块爬取的数据集网站不能访问。。因为该模块是用python内置urllib
模块来下载数据的,需要提供有效的数据集网站地址。
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets(data_dir, one_hot=True)
首先我看看read_data_sets()
函数源码:
def read_data_sets(train_dir,
fake_data=False,
one_hot=False,
dtype=dtypes.float32,
reshape=True,
validation_size=5000,
seed=None,
source_url=DEFAULT_SOURCE_URL):
if fake_data:
def fake():
return DataSe