关于numpy.argmax及mnist中one_hot=True

目录

mnist是深度学习中必不可少的数据集,它是一个28*28个像素的数字图片转化之后的数据格式,但是在tensorflow它已经被压缩在文件中,并且利用脚本可以直接下载(input.py下载不下来可以直接百度,能够找到源代码,放在sys.path路径中即可),但是在数据导入时one_hot的含义不是特别清楚

关于one_hot

在minist导入时会指定one_hot=True,如下:

import numpy as np
import tensorflow as tf
import input_data
mnist_data=input_data.read_data_sets("MNIST_data/",one_hot=True)

关于one_hot的具体含义,在导入数据时,将标签不是用指定的0-9的数字出现的,而是以numpy的数组的格式出现的,应该是为了更好地用numpy处理数据,如下:
输入数据

import numpy as np
import tensorflow as tf
import input_data

mnist_data=input_data.read_data_sets("MNIST_data/",one_hot=True)
mnist=mnist_data
print (type(mnist_data.train.labels[8,:]))
print ((mnist_data.train.labels[8,:])

输出内容

##输出结果
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
<type 'numpy.ndarray'>
[0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]#这个array就代表9

输入数据(未设置one_hot=True)

import numpy as np
import tensorflow as tf
import input_data

mnist_data=input_data.read_data_sets("MNIST_data/")
mnist=mnist_data
print (type(mnist_data.train.labels[8,:]))
print ((mnist_data.train.labels[8,:])

输出内容

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
<type 'numpy.uint8'>
9

我们可以通过numpy的argmax函数对阵列进行转换,这个函数就是返回阵列中最大数值的索引值(从0开始的索引),对于二维数据则是返回每一列的最大值的索引值
输入数据

>>> import numpy as np
>>> np.argmax([0,0,0,0,0,0,0,0,1])
8
>>> np.argmax([0,0,0,0,0,0,0,0,0,1])
9
  • 0
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值