使用python3.8+tensorflow对mnist数据集进行训练
tensorflow2.X新特性
从运行机制上讲,Tensorflow1.x 与 Tensorflow2.0的主要区别在于tf1.x使用静态图而tf2.x使用Eager Mode动态图。对于我们这种框架使用者来说,就是看API是否易于调用,tf2.x在API上进行了优化,不再需要我们创建session对象和运行,而是框架自动创建和运算。具体参考另外一个大神的分析:https://blog.csdn.net/keeppractice/article/details/105934521?utm_medium=distribute.pc_relevant.none-task-blog-BlogCommendFromBaidu-5.nonecase&depth_1-utm_source=distribute.pc_relevant.none-task-blog-BlogCommendFromBaidu-5.nonecase
基于python3.8和tensorflow2.0实现mnist数据集训练
具体代码如下(有简单注释):
import tensorflow as tf
import matplotlib.pyplot as plt
mnist = tf.keras.datasets.mnist
(x_train,y_train),(x_test,y_test) = mnist.load_data()
#设置超参数
batch_size = 128epoch = 10
#数据降维
x_train = x_train.reshape(60000,28*28)
x_test = x_test.reshape(10000,28*28)
#数据归一化处理
x_train = x_train/255.0
x_test = x_test/255.0
#控制台打印样本信息
#print(x_train[0],'train samples')
#print(x_test[0],'test samples')
#对标签进行one_hot编码,将标签转为二进制表示形式
y_train = tf.one_hot(y_train,10)
y_test = tf.one_hot(y_test,10)
#创建模型
model = tf.keras.Sequential()
model.add(tf.keras.layers.Input(shape=x_train.shape[1:]))
model.add(tf.keras.layers.Dense(30,activation=tf.nn.relu))
model.add(tf.keras.layers.Dense(10,activation=tf.nn.softmax))
model.summary()
#设置模型参数
ops = tf.keras.optimizers.SGD(learning_rate=0.005) #优化器
#lossf = tf.keras.losses.categorical_crossentropy(y_train,y_pred) #损失函数
#metric = tf.keras.metrics.Accuracy() #评价标准
model.compile(
#optimizer='sgd',
optimizer=ops,
loss='categorical_crossentropy',
metrics='accuracy')
#训练模型
model.fit(
x = x_train,
y = y_train,
batch_size = batch_size,
epochs= epoch,
validation_data=(x_test,y_test)
)
#测试预测结果
n = 10
predict = model.predict(x_test[:10],n)
fig = plt.figure(figsize=(10,2))
for i in range(n):
plt.subplot(1,10,i+1)
p = x_test[i].reshape(28,28)
plt.imshow(p,cmap='gray')
plt.legend()
if tf.argmax(y_test[i])== tf.argmax(predict[i]):
plt.title(str(tf.argmax(y_test[i]).numpy()) + ',' + str(tf.argmax(predict[i]).numpy()),color='green')
else:
plt.title(str(tf.argmax(y_test[i]).numpy()) + ',' + str(tf.argmax(predict[i]).numpy()),color='red')
plt.show()
这是我进行简单的训练得到的结果,迭代次数仅仅为10次,用时10s左右,其训练样本的准确率为89.3%,测试样本的成功率为89.7%,样本为欠拟合,还能继续增加训练次数或者调整超参数进行优化训练: ![Alt]
这是使用训练好的模型去预测测试数据集,选取其中一部分做的可视化:
鸣谢:
菜鸟一枚,如果有什么不足的地方,恳请指出。有什么疑问,也可以提出,相互学习进步。