如何对TensorFlow_DataSets下的MNIST数据集进行画图

背景

TensorFlow平台作为最常用的机器学习平台非常适合初学者学习研究。其中 TensorFLow_DataSets模块自带的数据集,更是免去了学习过程中数据搜集不到的尴尬,这里作者简要介绍其中的一种MNIST数据集的画图操作

代码内容

详细代码内容如下:

mport tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow import keras
import matplotlib.pyplot as plt
mnist_train,mnist_test=tfds.load(name="mnist",split=['train','test'],as_supervised=True)
mnist_sample=mnist_train.take(5)
for i,(image,label) in enumerate(tfds.as_numpy(mnist_sample)):
    plt.subplot(1,5,i+1)
    plt.imshow(image,cmap='gray')
    plt.title(f'label {label}')
    plt.axis('off')
plt.show()

这里的代码首先使用tensorflow_datasets (简称tfds)读取mnist数据集,需要注意的是第一次读取过程中需要从服务器上下载数据,因此会比较慢,方法中的name属性表示要下载的数据集名称,split属性表示将数据集切片为“训练集”和“测试集”,as_supervised属性设置导出的数据集是否包含Label属性。

mnist_train,mnist_test=tfds.load(name=“mnist”,split=[‘train’,‘test’],as_supervised=True)```

接下来从训练集中选取5个样本作为样本集
···
mnist_sample=mnist_train.take(5)
···
最后开始画图。使用tfds自带的as_numpy方法将样本集转换为numpy数组,对该数组进行遍历,可以得到遍历索引和一个包含图片像素以及标签的元祖。样本一共有五个,因此添加5个子图,调用plt.imshow()方法可根据图片像素值进行画图处理,在每张完成的图添加一个样本label属性标签。最后 调用show()方法展现图片

···
for i,(image,label) in enumerate(tfds.as_numpy(mnist_sample)):
    plt.subplot(1,5,i+1)
    plt.imshow(image,cmap='gray')
    plt.title(f'label {label}')
    plt.axis('off')
plt.show()
···
  • 8
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值