Tesnsorflow2.0 kears 读取数据

Tensorflow 2.0 的keras库中,提供了读取训练数据集的接口,可以读取
boston_housing,cifar10,cifa100,fashion_mnist,imdb,mnist和reuters等数据集。
在这里插入图片描述
这些数据的访问都是通过load_data()实现的。由于网络的原因,下载经常不成功。分析一下Google提供的源码可以知道load_data的实现并不复杂。学习Tensorflow的同学可以自己下载源数据文件后,自己编写一个load_data函数就可以正常读取训练集了。
下面的代码是mnist的load_data:

@keras_export('keras.datasets.mnist.load_data')
def load_data(path='mnist.npz'):
  """Loads the MNIST dataset.
  Arguments:
      path: path where to cache the dataset locally
          (relative to ~/.keras/datasets).
  Returns:
      Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
  License:
      Yann LeCun and Corinna Cortes hold the copyright of MNIST dataset,
      which is a derivative work from original NIST datasets.
      MNIST dataset is made available under the terms of the
      [Creative Commons Attribution-Share Alike 3.0 license.](
      https://creativecommons.org/licenses/by-sa/3.0/)
  """
  origin_folder = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/'
  path = get_file(
      path,
      origin=origin_folder + 'mnist.npz',
      file_hash=
      '731c5ac602752760c8e48fbffcf8c3b850d9dc2a2aedcf2cc48468fc17b673d1')
  with np.load(path, allow_pickle=True) as f:
    x_train, y_train = f['x_train'], f['y_train']
    x_test, y_test = f['x_test'], f['y_test']

    return (x_train, y_train), (x_test, y_test)

主要的功能文件下载是使用了get_file。我们只需要修改一下get_file 函数的cache_folder和cache_subdir参数就可以实现本地mnist数据集读取了。

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

def load_data(path='mnist.npz'):
  origin_folder = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/'
  path = tf.keras.utils.get_file(
      path,
      origin=origin_folder + 'mnist.npz',
      cache_dir='DataSet/',
      cache_subdir=""
      )
  with np.load(path, allow_pickle=True) as f:
    x_train, y_train = f['x_train'], f['y_train']
    x_test, y_test = f['x_test'], f['y_test']

    return (x_train, y_train), (x_test, y_test)
(x_train,y_train),(x_test,y_test)=load_data(path='mnist.npz')
print(x_train.shape,y_train.shape,x_test.shape,y_test.shape)

get_file 的功能主要是下载文件、保存文件以及从cache中读取缓存文件。我们已经下载好了mnist.npz,就可以直接从缓存中读取文件,不用下载了。但是origin还是要提供正确的url格式,因为get_file会检查。
cache_dir这个参数设置为mnist.npz所在的目录。cache_subdir设置为空。
下面的代码展示了文件中的内容:

index=89
img = x_train[index]
plt.imshow(img)
print(y_train[index])

在这里插入图片描述

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值