How to load the mnist dataset with tensorflow_datasets from within China
From within China, tensorflow_datasets cannot download its data (the Google-hosted storage is unreachable); you can work around this either by downloading the data on Kaggle or by loading it another way and converting the format.
The error
import numpy as np
import tensorflow_datasets as tfds
dataset, info = tfds.load(
    "mnist", data_dir="gs://tfds-data/datasets", with_info=True, as_supervised=True
)
Error output:
2024-08-23 17:25:05.371154: W tensorflow/tsl/platform/cloud/google_auth_provider.cc:184] All attempts to get a Google authentication bearer token failed, returning an empty token. Retrieving token from files failed with "NOT_FOUND: Could not locate the credentials file.". Retrieving token from GCE failed with "FAILED_PRECONDITION: Error executing an HTTP request: libcurl code 6 meaning 'Couldn't resolve host name', error details: Could not resolve host: metadata.google.internal".
2024-08-23 17:26:06.386193: E tensorflow/tsl/platform/cloud/curl_http_request.cc:610] The transmission of request 0x213a5b00 (URI: https://www.googleapis.com/storage/v1/b/tfds-data/o/datasets%2Fmnist?fields=size%2Cgeneration%2Cupdated) has been stuck at 0 of 0 bytes for 61 seconds and will be aborted.
...
RuntimeError: AbortedError: Failed to construct dataset mnist: All 10 retry attempts failed. The last failure: Error executing an HTTP request: libcurl code 42 meaning 'Operation was aborted by an application callback', error details: Callback aborted
when reading metadata of gs://tfds-data/datasets/mnist
Method 1: Download the data on Kaggle
The idea is to run tfds.load in a Kaggle notebook (which can reach Google's servers), cache the data to a local path there, and then download that cache to your own machine.
Set the cache path and download the data
import tensorflow_datasets as tfds
dataset, info = tfds.load(
    "mnist", data_dir="tfds-data/datasets", with_info=True, as_supervised=True
)
2024-08-24 10:08:22.143742: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-08-24 10:08:23.073311: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
2024-08-24 10:08:24.855020: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1635] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 21127 MB memory: -> device: 0, name: NVIDIA GeForce RTX 4090, pci bus id: 0000:65:00.0, compute capability: 8.9
Here, data_dir is the cache location; once it is set, tfds.load saves the downloaded data under that path.
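Before zipping, you can verify that the files actually landed in the cache. A minimal check (not part of the original notebook; the 3.0.1 version subdirectory seen in the zip output below may differ across TFDS releases):
from pathlib import Path

# List everything tfds.load wrote under data_dir; expect TFRecord shards and
# JSON metadata under a version subdirectory such as mnist/3.0.1.
for p in sorted(Path("tfds-data/datasets").rglob("*")):
    print(p)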
Zip the cache folder
!zip -r datasets.zip tfds-data/datasets
updating: tfds-data/datasets/ (stored 0%)
updating: tfds-data/datasets/mnist/ (stored 0%)
updating: tfds-data/datasets/mnist/3.0.1/ (stored 0%)
updating: tfds-data/datasets/mnist/3.0.1/image.image.json (deflated 4%)
updating: tfds-data/datasets/mnist/3.0.1/mnist-test.tfrecord-00000-of-00001 (deflated 27%)
updating: tfds-data/datasets/mnist/3.0.1/mnist-train.tfrecord-00000-of-00001 (deflated 27%)
updating: tfds-data/datasets/mnist/3.0.1/features.json (deflated 73%)
updating: tfds-data/datasets/mnist/3.0.1/dataset_info.json (deflated 73%)
updating: tfds-data/datasets/downloads/ (stored 0%)
updating: tfds-data/datasets/downloads/extracted/ (stored 0%)
Download the zip file
On the Kaggle notebook page, find datasets.zip under the Output section on the right and download it to your machine.
Unzip and load the data
- Unzip the data into a working directory such as "tfds-data/datasets".
- Point data_dir in tfds.load at that path; it will then load the already-downloaded data directly instead of trying to download it again.
dataset, info = tfds.load(
    "mnist", data_dir="tfds-data/datasets", with_info=True, as_supervised=True
)
mnist_train, mnist_test = dataset["train"], dataset["test"]
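The two splits can then be fed into a standard input pipeline. A minimal sketch of normalizing and batching the (image, label) pairs returned by as_supervised=True (the batch size and shuffle buffer are illustrative choices, not from the original post):
import tensorflow as tf

def normalize(image, label):
    # Cast uint8 pixels to float32 in [0, 1]; labels pass through unchanged.
    return tf.cast(image, tf.float32) / 255.0, label

train_ds = (
    mnist_train
    .map(normalize, num_parallel_calls=tf.data.AUTOTUNE)
    .shuffle(10_000)
    .batch(128)
    .prefetch(tf.data.AUTOTUNE)
)
test_ds = mnist_test.map(normalize).batch(128)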
Method 2: Convert the format
Load the data by hand or with another third-party library, then convert it into a tf.data.Dataset. As an example, this post loads mnist with keras and converts it.
Load the mnist dataset with keras.datasets
import tensorflow as tf
from tensorflow.keras import datasets
mnist_train, mnist_test = datasets.mnist.load_data()
# Reshape the images to (N, 28, 28, 1) so the dimensions match what tensorflow_datasets produces
mnist_train = (mnist_train[0].reshape((60000, 28, 28, 1)), mnist_train[1])
mnist_test = (mnist_test[0].reshape((10000, 28, 28, 1)), mnist_test[1])
# Convert the (images, labels) tuples into tf.data.Dataset objects
mnist_train = tf.data.Dataset.from_tensor_slices(mnist_train)
mnist_test = tf.data.Dataset.from_tensor_slices(mnist_test)
mnist_train and mnist_test can then be used for training and evaluating a model, just like the datasets returned by tfds.load.
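As an optional sanity check (not in the original post), element_spec should show a (28, 28, 1) image and a scalar label, matching the layout tfds.load yields with as_supervised=True; note that keras gives uint8 labels here whereas tfds typically yields int64, so cast the labels if downstream code depends on the dtype:
# Inspect the structure of the converted datasets.
print(mnist_train.element_spec)
print(mnist_test.element_spec)
From here the converted datasets can go through the same normalize/shuffle/batch pipeline shown in Method 1 before being passed to a model.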