1. 配置库文件
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow import keras
import numpy as np
2. 加载数据集
tensorflow加载数据集比较便捷,下面是在线加载和本地加载方式:
# 下面一行是在线加载方式
# mnist = tf.keras.datasets.mnist
# 下面两行是加载本地的数据集
datapath = r'E:\Pycharm\project\project_TF\.idea\data\mnist.npz'
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data(datapath)
x_train = tf.keras.utils.normalize(x_train, axis=1)
x_test = tf.keras.utils.normalize(x_test, axis=1) #归一化
3. 建立全连接网络模型
model = tf.keras.Sequential([ # 3 个非线性层的嵌套模型
tf.keras.layers.Flatten(), #将多维数据打平
tf.keras.layers.Dense(784, activation='relu'), # 128也行
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10, activation='softmax') # softmax分类
])
# 打印模型
model.build((None,784,1)) # 这里要先build,告诉模型数据输入格式
print(model.summary())
打印出的网络结构如下:
4.模型编译与训练
model.compile(optimizer='adam', # 优化器
loss='sparse_categorical_crossentropy', # 交叉熵损失函数
metrics=['accuracy']) # 标签
# 训练模型
model.fit(x_train, y_train, epochs=10,verbose=1) # verbose为1表示显示训练过程
5. 模型测试
这里提供两种方式给出精确度。可以据此了解model.evaluate()与model.predict()的区别,加深了解。
第一种(推荐):
#这里是测试模型
val_loss, val_acc = model.evaluate(x_test, y_test) # model.evaluate是输出计算的损失和精确度
print('First test Loss:{:.6f}'.format(val_loss)
第二种(加深理解):
#测试模型方式二
acc_correct = 0
predictions = model.predict(x_test) # model.perdict是输出预测结果
for i in range(len(x_test)):
if (np.argmax(predictions[i]) == y_test[i]): # argmax是取最大数的索引,放这里是最可能的预测结果
acc_correct += 1
print('Second test accuracy:{:.6f}'.format(acc_correct*1.0/(len(x_test))))
至此,完整的全连接神经网络程序完成。
6.程序运行结果
训练集中精确度达到了0.9967,测试集中精确度达到了0.9787。而采用卷积神经网络能够进一步增加精确度。
7.对np.argmax()工作的一点探索
由于对argmax()工作不是很理解,故做了以下测试,最终弄明白了。
i = 0 #测试集第一张图片
plt.imshow(x_test[i],cmap=plt.cm.binary)
plt.show()
print(np.argmax(predictions[i])) # argmax输出的是最大数的索引,predicts[i]是十个分类的权值
print((predictions[i])) # 比如predicts[0]最大的权值是第八个数,索引为7,故预测的数字为7
输出结果如下:
plt.show()输出的是测试集第一个图片(手写数字7);第一个print输出7;第二个print输出了Predictions[0]的值。
通过观察这10个权值可以看出,第八个权值最大,到了1;而argmax()是输出最大数的索引,权值最大的第八个数索引为7,故第一个print输出是7,即预测出了图片的手写数字是7。
因此,能够很清楚地了解argmax的作用及用法。