目录
mnist是深度学习中必不可少的数据集,它是一个28*28个像素的数字图片转化之后的数据格式,但是在tensorflow它已经被压缩在文件中,并且利用脚本可以直接下载(input.py下载不下来可以直接百度,能够找到源代码,放在sys.path路径中即可),但是在数据导入时one_hot的含义不是特别清楚
关于one_hot
在minist导入时会指定one_hot=True,如下:
import numpy as np
import tensorflow as tf
import input_data
mnist_data=input_data.read_data_sets("MNIST_data/",one_hot=True)
关于one_hot的具体含义,在导入数据时,将标签不是用指定的0-9的数字出现的,而是以numpy的数组的格式出现的,应该是为了更好地用numpy处理数据,如下:
输入数据
import numpy as np
import tensorflow as tf
import input_data
mnist_data=input_data.read_data_sets("MNIST_data/",one_hot=True)
mnist=mnist_data
print (type(mnist_data.train.labels[8,:]))
print ((mnist_data.train.labels[8,:])
输出内容
##输出结果
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
<type 'numpy.ndarray'>
[0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]#这个array就代表9
输入数据(未设置one_hot=True)
import numpy as np
import tensorflow as tf
import input_data
mnist_data=input_data.read_data_sets("MNIST_data/")
mnist=mnist_data
print (type(mnist_data.train.labels[8,:]))
print ((mnist_data.train.labels[8,:])
输出内容
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
<type 'numpy.uint8'>
9
我们可以通过numpy的argmax函数对阵列进行转换,这个函数就是返回阵列中最大数值的索引值(从0开始的索引),对于二维数据则是返回每一列的最大值的索引值
输入数据
>>> import numpy as np
>>> np.argmax([0,0,0,0,0,0,0,0,1])
8
>>> np.argmax([0,0,0,0,0,0,0,0,0,1])
9