Tensorflow keras 极简神经网络构建与使用
Keras在希腊语中意为号角,它来自古希腊和拉丁文学中的一个文学形象。
以 mnist 数据集为例, 构建一个神经网络实现手写数字的训练与测试, 首先我们需要认识一下 mnist 数据集, mnist 数据集有 6 万张手写图像, 1 万张测试图像. Keras 通过 datase 来下载与使用 mnist 数据集, 下载与读取的代码如下:
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) =mnist.load_data()
通过下面的代码可以显示手写数字图像:
print(train_labels[0])
for i in range(25):
plt.subplot(5,5,i+1)
plt.xticks([])
plt.yticks([ ])
plt.grid(False)
plt.imshow(train_images[i], cmap=plt.cm.gray)
plt.xlabel(str(train_labels[i]))
plt.show()
对数据 re-scale 到 0~1.0 之间, 对标签进行了 one-hot 编码, 代码如下:
# re-scale to 0~1.0 之间
train_images = train_images / 255.0
test_images = test_images / 255.0
train_labels = one_hot(train_labels)
test_labels = one_hot(test_labels)
其中 one-hot 编码函数如下:
def one_hot(labels):
onehot_labels = np.zeros(shape=[len</