在运行 TensorFlow 例程时,经常需要下载一些数据文件,有些数据文件比较大,网络条件不好的时候会花费较长时间,而且这段时间干等也不是办法,有没有一种一劳永逸的方法避免重复下载数据呢?答案是肯定的,我们从源码中找找答案。
打开 MNIST 例程(https://github.com/tensorflow/tensorflow/blob/r0.12/tensorflow/models/image/mnist/convolutional.py),找到 maybe_download 这个函数,代码如下:
其中 WORK_DIRECTORY 已经在前面第 38 行定义
WORK_DIRECTORY = 'data'
在 maybe_download 函数中,首先判断这个目录是否存在,如果不存在,则创建它。
之后,判断 WORK_DIRECTORY 下的 filename 对应的文件是否存在,如果不存在,则利用 urllib 模块发起 HTTP 请求,从 SOURCE_URL 下载。
SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
下载之后会打印相应下载文件大小信息。
看完 maybe_download 这个函数定义,再看看它是怎么调用的,以及传递的 filename 参数究竟是什么。在第 129~133 行我们找到了答案:
可见,主函数在准备数据阶段,下载了以下 4 个文件:
train-images-idx3-ubyte.gz
train-labels-idx1-ubyte.gz
t10k-images-idx3-ubyte.gz
t10k-labels-idx1-ubyte.gz
它们均位于 ./data/ 目录下。
在一台安装了 TensorFlow 的机器上(无论用哪种方法安装),运行这个例子。在任意目录下执行:
python -m tensorflow.models.image.mnist.convolutional
第一次执行,会下载上述 4 个文件,在运行日志中可以看到:
而后面在相同目录下再运行这个例程,则不会再次重复下载,在运行日志中可以看到:
从此,你在其他目录运行这个例程时,记得把 ./data/ 也一并拷贝过去,这样就能省去下载数据的时间。
除了 MNIST 例程之外,其他例程像 CIFAR10,seq2seq,label_image 等都有这个特点,只是数据存放路径略有不同,需要读者自己研究下代码,找到运行时下载的数据文件,备份它们,方便下次使用。
微信扫描如下二维码关注此公众号!