1.用卷积训练mnist
自己照着网上写了个用keras卷积模型训练mnist数据集的代码,熟悉一下keras的一些知识。
import tensorflow as tf
class myCallback(tf.keras.callbacks.Callback):
def on_epoch_end(self, epoch, logs={}):
if (logs.get('accuracy') > 0.998):
print("\n准确度达到99.8%,停止训练")
self.model.stop_training = True
callbacks = myCallback()
mnist = tf.keras.datasets.mnist
(training_images, training_labels), (test_images, test_labels) = mnist.load_data()
training_images, test_images = training_images.reshape(60000, 28, 28, 1) / 255.0, test_images.reshape(10000, 28, 28,1) / 255.0 #reshape,将输出调整为特定形状
model = tf.keras.models.Sequential([
tf.keras.layers.Conv2D(64, (3, 3), activation='relu', input_shape=(28, 28, 1)),#当该层作为第一层时,需要提供input_shape参数(整数元组,不包含样本表示的轴)
#64指的是卷积核的数量,(3,3)指的是卷积核的大小
tf.keras.layers.MaxPooling2D(2, 2),
tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
tf.keras.layers.MaxPooling2D(2, 2),
tf.keras.layers.Flatten(),#将输入展平。不影响批量大小。
tf.keras.layers.Dense(512, activation=tf.nn.relu),
tf.keras.layers.Dense(10, activation=tf.nn.softmax)
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
# model fitting
history = model.fit(
training_images,
training_labels,
epochs=20,
callbacks=[callbacks])
keras的Conv层只提供自带的卷积核,自定义卷积核需要一些麻烦的操作,还没学会。Tensorflow中的tensorflow.nn模块是TensorFlow用于深度学习计算的核心模块。
跑了好长时间,训练速度有待优化…
如果只看keras的知识的话,基本就是看着忘着,或者留下个大致印象,只有不断打代码巩固才能掌握的更牢固。
2.普通网络训练mnist
import numpy as np
import tensorflow.keras as keras
import tensorflow as tf
import matplotlib.pyplot as plt
data=keras.datasets.mnist
(train_images,train_labels),(test_images,test_labels)=data.load_data()
class_name=['0','1','2','3','4','5','6','7','8','9','10']
train_images=train_images/255.0
test_images=test_images/255.0
model=keras.Sequential([keras.layers.Flatten(input_shape=(28,28)),
keras.layers.Dense(128,activation='relu'),
keras.layers.Dense(10,activation='softmax')])
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])
model.fit(train_images,train_labels,epochs=5)
#这上边的代码大家应该都熟悉,下边用plt将预测的图片显示出来
prediction=model.predict(test_images)
for i in range(5):
plt.grid(False)
plt.imshow(test_images[i])
plt.xlabel('Actual:'+class_name[test_labels[i]])
plt.title("prediction:"+class_name[np.argmax(prediction[i])])
plt.show()
预测的与结果一致。
3.零碎的函数
3.1显示数据集信息
print(train_images.shape)#会显示(60000,28,28)
3.2显示模型信息
看视频的时候,有个函数model.summary(),print的话,就会显示出来网络的信息。当然,model名字是自定的。