MNIST手写数字识别——TensorFlow版本

MNIST手写数字识别——TensorFlow版本
  • python版本:python = 3.6.15
  • anaconda配置python安装Tensorflow:
    • 1.创建版本号为3.6的Tensorflow虚拟环境conda create -n Tensorflow python=3.6
    • 2.激活Tensorflow环境conda activate Tensorflow
    • 3.安装Tensorflowpip install tensorflow -i https://pypi.tuna.tsinghua.edu.cn/simple
零、引用相应的库
  • tensorflow.keras.datasets:导入MNIST手写数字数据库
  • tensorflow.keras.layers:建立CNN各个层(卷积,池化,全连接)
  • tensorflow.keras.model:建立CNN模型
  • matplotlib.pyplot:进行数据可视化
import tensorflow as tf
from tensorflow.keras import datasets,layers,models,losses
import matplotlib.pyplot as plt
一、对数据的处理以及可视化
1.1 数据处理
# 导入数据
(train_ds,train_label),(test_ds,test_label) = datasets.mnist.load_data()

# 归一化
train_ds = train_ds / 255.0
test_ds = test_ds / 255.0

# reshape
train_ds = train_ds.reshape((60000,28,28,1))
test_ds = test_ds.reshape((10000,28,28,1))

1.2 数据可视化
# 数据可视化
plt.figure(figsize=(20,10))
for i in range(20):
    plt.subplot(2,10,i+1)
    # 不显示XY轴
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    # plt.imshow(train_ds[i]) # 有颜色的
    plt.imshow(train_ds[i],cmap=plt.cm.binary) # cmap是颜色图谱的参数,是彩色的
    plt.xlabel(train_label[i])
plt.show()
二、模型的搭建与训练
# model
model = models.Sequential([
    # 输入 28*28*1--26*26*1
    layers.Conv2D(32,(3,3),1,activation='relu',input_shape = (28,28,1)),
    # 池化 maxpool 26*26*1 -- 13*13*1
    layers.MaxPooling2D((2,2),2),
    # 输入 13*13*1 -- 11*11*1
    layers.Conv2D(64,(3,3),1,activation='relu'),
    # 池化 maxpool 11*11*1 -- 5*5*1
    layers.MaxPooling2D((2,2),2),
    # 展平
    layers.Flatten(),
    # 全连接隐藏层
    layers.Dense(64,activation='relu'),
    # 全连接Output
    layers.Dense(10)
])
#输出model的模式图:
model.summary()

# 模型的编译
model.compile(
    optimizer = 'adam',
    loss = losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy']
)

# 模型的训练
model.fit(
    train_ds,
    train_label,
    epochs=10,
    validation_data=(test_ds,test_label)
)
三、训练结果的可视化
# Epochs--Loss关系可视化
# Epochs--Accuracy关系可视化
def loss_draw(history):
    loss = history.history['loss']
    epochs = range(1,len(loss)+1)
    # 第一张子图 epochs 和 loss
    plt.subplot(1,2,1)
    plt.plot(epochs,loss,'bo',label = 'Training Loss')
    # 子图的标题
    plt.title('Training Loss')
    # x轴和y轴
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    # 图例
    plt.legend()

    plt.subplot(1,2,2)
    accuracy = history.history['accuracy']
    plt.plot(epochs,accuracy,'bo',label = 'Training Accuracy')
    plt.title('Training Accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.suptitle('Train data')
    plt.legend()
    plt.show()

loss_draw(history=history)
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值