TensorFlow+keras

1导入库

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from tensorflow import keras
%matplotlib inline

2 读取MNIST数据集

mnist = keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()

3 重构数据至四维

x_train=x_train.reshape(x_train.shape+(1,))
x_test=x_test.reshape(x_test.shape+(1,))
x_train, x_test = x_train/255.0, x_test/255.0

4 数据标签

label_train = keras.utils.to_categorical(y_train, 10)
label_test = keras.utils.to_categorical(y_test, 10)

5 模型构建

model = tf.keras.models.Sequential([tf.keras.layers.Conv2D(64,7,activation="relu",padding="same",input_shape=[28,28,1]),
                                   tf.keras.layers.MaxPooling2D(2),
                                   tf.keras.layers.Conv2D(128,3,activation="relu",padding="same"),
                                   tf.keras.layers.Conv2D(128,3,activation="relu",padding="same"),
                                   tf.keras.layers.MaxPooling2D(2),
                                   tf.keras.layers.Conv2D(256,3,activation="relu",padding="same"),
                                   tf.keras.layers.Conv2D(256,3,activation="relu",padding="same"),
                                   tf.keras.layers.MaxPooling2D(2),
                                   tf.keras.layers.Flatten(),
                                   tf.keras.layers.Dense(128,activation="relu"),
                                   tf.keras.layers.Dropout(0.5),
                                   tf.keras.layers.Dense(64,activation="relu"),
                                   tf.keras.layers.Dropout(0.5),
                                   tf.keras.layers.Dense(10,activation="softmax")
                                   ])

6 模型显示

model.summary()

7 使用SGD编译模型

model.compile(optimizer="sgd",
             loss="categorical_crossentropy",
             metrics=["acc"])

8 学习20个纪元,使用20%数据交叉验证

history = model.fit(x_train,label_train,epochs=20,validation_split=0.2)

9 预测

y_pred = np.argmax(model.predict(x_test), axis=1)
print("prediction accuracy: {}".format(sum(y_pred==y_test)/len(y_test)))

10 绘制结果

plt.plot(records.history['loss'],label='training set loss')
plt.plot(records.history['val_loss'],label='validation set loss')
plt.ylabel('categorical cross-entropy'); plt.xlabel('epoch')
plt.legend()

11 模型训练精度及结果

11.1 精度

训练精度

11.2 损失曲线

损失曲线

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值