MNIST读取图片

MNIST数据集

这是一组由美国高中生和人口调查局员工手写的70 000个数字的图片。每张图片都用其代表的数字标记。这个数据集被广为使用,因此也被称作是机器学习领域的“Hello
World”:但凡有人想到了一个新的分类算法,都会想看看在MNIST上的执行结果。因此只要是学习机器学习的人,早晚都要面对MNIST。

获取数据集

from sklearn.datasets import fetch_openml
mnist = fetch_openml('mnist_784',version = 1)
mnist.keys() #字典

让我们分解这个调用:

  • fetch_openml:这是sklearn.datasets模块中的一个函数,用于从OpenML平台获取数据集。它允许用户指定数据集的名称、版本等参数,以检索所需的数据集。
  • 'mnist_784':这是传递给fetch_openml函数的第一个参数,指定了要下载的数据集的名称。在这个例子中,它指的是MNIST数据集的一个版本,其中数字图像被展平为784维的向量(因为MNIST图像是28x28像素的,所以28 * 28=784)。
  • version=1:这是传递给fetch_openml函数的一个关键字参数,用于指定要下载的数据集的版本。在OpenML上,数据集可能有多个版本,每个版本可能包含不同的数据或具有不同的预处理方式。在这个例子中,version=1指定了MNIST数据集的一个特定版本。

Scikit-Learn加载的数据集通常具有类似的字典结构,包括:

  • DESCR键,描述数据集。

  • data键,包含一个数组,每个实例为一行,每个特征为一列。

  • target键,包含一个带有标记的数组。

结果如下:

dict_keys(['data', 'target', 'frame', 'categories', 'feature_names', 'target_names', 'DESCR', 'details', 'url'])

接下来引入变量:

X =mnist['data']
y = mnist['target']

y表示X数据的数字标签。

import matplotlib as mlp
import matplotlib.pyplot as plt
import numpy as np
some_digit = np.array(X.iloc[3003,]) #逗号可要可不要 只需表达取一整行标签数据
some_digit_image = some_digit.reshape(28,28)
plt.imshow(some_digit_image, cmap = mlp.cm.binary, interpolation='bilinear') 
plt.axis('off') #关闭坐标轴显示
plt.show() 

结果如下:
![[QQ_1724669224009.png]]

让我们来看看y 的值是多少:

y[30003]
# 9

显示单个图片的函数:

def plot_digit(data):
	image = data.reshape(28,28)
	plt.imshow(image,cmap = mpl.cm.binary,interpolation = "nearest")
	plt.axis("off")

调用显示单个图片的函数

plt.figure(figsize = (9,9))
examplt_image = np.array(X.iloc[100])
plot_digit(example_image)
plt.show()
print(y[100])

![[QQ_1724729742350.png]]

显示更多图片的函数

def plot_images(instances , images_per_row = 10 , **options):
	images_per_row = min(len(instances) , images_per_row) 
	#每一行多少张。当总数不足10时。
	images =[ instance.reshape(size,size) for instance in instances]
	n_rows = (len(instances)-1) // images_per_row +1
	n_empty = n_rows * images_per_row - len(instances) 
    images.append(np.zeros((size, size * n_empty))) #最后一行不够的补上空白格
	row_images = []
	for row in range(n_rows):
		rimages  = images[row * images_per_row : (row +1 ) * images_per_row]
		row_images.append(np.concatenate(rimages,axis = 1))
	image = np.concatenate(row_images,axis = 0)
	plt.imshow(images,cmap = mpl.cm.binary ,**options)
	plt.axis("off")

需要明白的是 此处的rimages 是一个数组,即row_images.append 添加的是一个数组
这样才能用 np.concatenate(row_images,axis = 0)
调用函数

plt.figure(figsize= (9,9))
example_images = np.array(X[:100])
plot_digits(example_images , images_per_row = 10)
plt.show()

结果如下:
![[QQ_1724733290811.png]]

  • 5
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值