还是入坑了的TF2_1——实现全连接手写体识别
学习前言
还是要开始学习tf2呀,看看有没有意思!这玩意好像就是Keras呀,老熟人了,感觉很快乐!
重要函数
1、Model
Model用于建立模型。与Keras一样,可以传入Inputs和Outputs作为输入输出。很简单就可以构建一个模型。
使用方法如下:
# 建立模型
model = Model(inputs,out)
2、Input
Input用于建立输入量。与Keras一样,需要指定输入进来的内容的shape,可以是图片也可以是一维向量之类的。
使用方法如下:
# 作为输入
inputs = Input([28,28])
3、Dense
Dense用于往model中添加全连接层。全连接层示意图如下。
具体而言,简单的BP神经网络中,输入层到隐含层中间的权值连接,其实与全连接层的意义相同。
与Keras一样,需要指定全连接的神经元数量,还可以指定激活函数。
x = Flatten(input_shape=(28, 28))(inputs)
x = Dense(128, activation='relu')(x)
x = Dropout(0.2)(x)
out = Dense(10, activation='softmax')(x)
4、model.compile
model.compile主要用于定义loss函数和优化器。
其调用方式如下:
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
其中loss用于定义计算损失的损失函数,其可以选择的内容如下:
1、mse:均方根误差,常用于回归预测。
2、categorical_crossentropy:亦称作多类的对数损失,注意使用该目标函数时,需要将标签转化为形如(nb_samples, nb_classes)的二值序列,常用于分类。
3、sparse_categorical_crossentrop:如上,但接受稀疏标签。
metrics=[‘accuracy’]常用于分类运算中,accuracy代表计算分类精确度。
5、model.fit
用于接收训练数据用于训练:
# 利用fit进行训练
model.fit(x_train, y_train, epochs=5)
全部代码
import tensorflow as tf
from tensorflow.keras.layers import Flatten,Dense,Dropout,Input
from tensorflow.keras.models import Model
print(tf.__version__)
print(tf.keras.__version__)
# 载入Mnist手写数据集
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
# 作为输入
inputs = Input([28,28])
x = Flatten(input_shape=(28, 28))(inputs)
x = Dense(128, activation='relu')(x)
x = Dropout(0.2)(x)
out = Dense(10, activation='softmax')(x)
# 建立模型
model = Model(inputs,out)
# 设定优化器,loss,计算准确率
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
# 利用fit进行训练
model.fit(x_train, y_train, epochs=5)
model.evaluate(x_test, y_test, verbose=2)
输出如下:
2.0.0
2.2.4-tf
Epoch 1/5
60000/60000 [==============================] - 3s 45us/sample - loss: 0.2989 - accuracy: 0.9121
Epoch 2/5
60000/60000 [==============================] - 2s 36us/sample - loss: 0.1415 - accuracy: 0.9577
Epoch 3/5
60000/60000 [==============================] - 2s 38us/sample - loss: 0.1067 - accuracy: 0.9674
Epoch 4/5
60000/60000 [==============================] - 2s 35us/sample - loss: 0.0893 - accuracy: 0.9726
Epoch 5/5
60000/60000 [==============================] - 2s 36us/sample - loss: 0.0746 - accuracy: 0.9769
10000/1 - 0s - loss: 0.0427 - accuracy: 0.9756