关于numpy.argmax及mnist中one_hot=True

目录

mnist是深度学习中必不可少的数据集,它是一个28*28个像素的数字图片转化之后的数据格式,但是在tensorflow它已经被压缩在文件中,并且利用脚本可以直接下载(input.py下载不下来可以直接百度,能够找到源代码,放在sys.path路径中即可),但是在数据导入时one_hot的含义不是特别清楚

关于one_hot

在minist导入时会指定one_hot=True,如下:

import numpy as np
import tensorflow as tf
import input_data
mnist_data=input_data.read_data_sets("MNIST_data/",one_hot=True)

关于one_hot的具体含义,在导入数据时,将标签不是用指定的0-9的数字出现的,而是以numpy的数组的格式出现的,应该是为了更好地用numpy处理数据,如下:
输入数据

import numpy as np
import tensorflow as tf
import input_data

mnist_data=input_data.read_data_sets("MNIST_data/",one_hot=True)
mnist=mnist_data
print (type(mnist_data.train.labels[8,:]))
print ((mnist_data.train.labels[8,:])

输出内容

##输出结果
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
<type 'numpy.ndarray'>
[0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]#这个array就代表9

输入数据(未设置one_hot=True)

import numpy as np
import tensorflow as tf
import input_data

mnist_data=input_data.read_data_sets("MNIST_data/")
mnist=mnist_data
print (type(mnist_data.train.labels[8,:]))
print ((mnist_data.train.labels[8,:])

输出内容

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
<type 'numpy.uint8'>
9

我们可以通过numpy的argmax函数对阵列进行转换,这个函数就是返回阵列中最大数值的索引值(从0开始的索引),对于二维数据则是返回每一列的最大值的索引值
输入数据

>>> import numpy as np
>>> np.argmax([0,0,0,0,0,0,0,0,1])
8
>>> np.argmax([0,0,0,0,0,0,0,0,0,1])
9
  • 0
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
05-10
可以优化代码如下: ``` from sklearn.decomposition import PCA from sklearn.cluster import KMeans from sklearn.metrics import accuracy_score import numpy as np import matplotlib.pyplot as plt from tensorflow.examples.tutorials.mnist import input_data import datetime # 导入数据集 start = datetime.datetime.now() #计算程序运行时间 mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) X_train = mnist.train.images y_train = mnist.train.labels X_test = mnist.test.images y_test = mnist.test.labels # PCA降维 pca = PCA(n_components=10) X_train_pca = pca.fit_transform(X_train) X_test_pca = pca.transform(X_test) # 输出因子负荷量 print("PCA降维后的因子负荷量为:") print(pca.components_) # 可视化 plt.scatter(X_train_pca[:, 0], X_train_pca[:, 1], c=np.argmax(y_train, axis=1)) plt.show() # K-means聚类 kmeans_centers = [] # 用于存储初始类心 for i in range(10): idx = np.where(np.argmax(y_train, axis=1) == i)[0] # 获取第i类数字的索引列表 sample_idx = np.random.choice(idx) # 随机指定一个样本作为初始类心 kmeans_centers.append(X_train_pca[sample_idx]) # 将初始类心添加到列表 kmeans = KMeans(n_clusters=10,init=kmeans_centers,n_init=1) kmeans.fit(X_train_pca) # 计算分类错误率 y_pred = kmeans.predict(X_test_pca) acc = accuracy_score(np.argmax(y_test, axis=1), y_pred) print("分类错误率:{:.2%}".format(1-acc)) # 计算程序运行时间 end = datetime.datetime.now() print("程序运行时间为:"+str((end-start).seconds)+"秒") ``` 输出结果包含了PCA降维后的因子负荷量,即`pca.components_`。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值