背景
TensorFlow平台作为最常用的机器学习平台非常适合初学者学习研究。其中 TensorFLow_DataSets模块自带的数据集,更是免去了学习过程中数据搜集不到的尴尬,这里作者简要介绍其中的一种MNIST数据集的画图操作
代码内容
详细代码内容如下:
mport tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow import keras
import matplotlib.pyplot as plt
mnist_train,mnist_test=tfds.load(name="mnist",split=['train','test'],as_supervised=True)
mnist_sample=mnist_train.take(5)
for i,(image,label) in enumerate(tfds.as_numpy(mnist_sample)):
plt.subplot(1,5,i+1)
plt.imshow(image,cmap='gray')
plt.title(f'label {label}')
plt.axis('off')
plt.show()
这里的代码首先使用tensorflow_datasets (简称tfds)读取mnist数据集,需要注意的是第一次读取过程中需要从服务器上下载数据,因此会比较慢,方法中的name属性表示要下载的数据集名称,split属性表示将数据集切片为“训练集”和“测试集”,as_supervised属性设置导出的数据集是否包含Label属性。
mnist_train,mnist_test=tfds.load(name=“mnist”,split=[‘train’,‘test’],as_supervised=True)```
接下来从训练集中选取5个样本作为样本集
···
mnist_sample=mnist_train.take(5)
···
最后开始画图。使用tfds自带的as_numpy方法将样本集转换为numpy数组,对该数组进行遍历,可以得到遍历索引和一个包含图片像素以及标签的元祖。样本一共有五个,因此添加5个子图,调用plt.imshow()方法可根据图片像素值进行画图处理,在每张完成的图添加一个样本label属性标签。最后 调用show()方法展现图片
···
for i,(image,label) in enumerate(tfds.as_numpy(mnist_sample)):
plt.subplot(1,5,i+1)
plt.imshow(image,cmap='gray')
plt.title(f'label {label}')
plt.axis('off')
plt.show()
···