tensorflow2.0---2.利用tensorflow2.0进行mnist分类实战

有了一些tensorflow2.0的基础,然后加上之前对tensorflow也有了解,所以今天用tensorflow2.0做个mnist的分类实战,这也相当于“hello world”吧。

首先看看最后的训练结果图,网络结构图

在这里插入图片描述
在这里插入图片描述
可以看到网络结构其实很简单,就是输入层,然后两层中间层,最后一个输出层用softmax

代码部分

import tensorflow as tf
import numpy as np
import math
from tensorflow import keras
from tensorflow.keras import datasets,layers,optimizers,Sequential
from tensorflow.keras.callbacks import TensorBoard
import datetime
import matplotlib.pyplot as plt
plt.rcParams['axes.unicode_minus']=False


#导入数据
(x_train,y_train),(x_test,y_test)=tf.keras.datasets.mnist.load_data()
print(x_train.shape,x_test.shape)

在这里插入图片描述

#预处理---正规化
def normalize(x,y):
    x=tf.cast(x,tf.float32)
    x/=255
    return x,y

#添加一层维度,方便后续扁平化
x_train=tf.expand_dims(x_train,axis=-1)
x_test=tf.expand_dims(x_test,axis=-1)

train_dataset=tf.data.Dataset.from_tensor_slices((x_train,y_train))
test_dataset=tf.data.Dataset.from_tensor_slices((x_test,y_test))
train_dataset=train_dataset.map(normalize)
test_dataset=test_dataset.map(normalize)


#画图
plt.figure(figsize=(10,15))
i=0
for (x_test,y_test) in test_dataset.take(25):
    x_test=x_test.numpy().reshape((28,28))
    plt.subplot(5,5,i+1)
    plt.grid(False)
    plt.xticks([])
    plt.imshow(x_test,cmap=plt.cm.binary)
    plt.xlabel([y_test.numpy()],fontsize=10)
    i+=1
plt.show()

在这里插入图片描述

#开始定义模型
model=tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28,28,1)),
    tf.keras.layers.Dense(128,activation=tf.nn.relu),
    tf.keras.layers.Dense(64,activation=tf.nn.relu),
    tf.keras.layers.Dense(10,activation=tf.nn.softmax)
])

model.summary()

在这里插入图片描述

# 模型编译
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])

#开始训练
batch_size=32
train_dataset=train_dataset.repeat().shuffle(60000).batch(batch_size)
test_dataset=test_dataset.batch(batch_size)
#为tensorboard可视化保存数据
tensorboard_callback=tf.keras.callbacks.TensorBoard(histogram_freq=1)
model.fit(train_dataset,epochs=5,steps_per_epoch=math.ceil(60000/batch_size),callbacks=[tensorboard_callback])

在这里插入图片描述
这里本来想自己定义保存文件的路径的,但是运行就出错,没找到修改的办法只能删除自己的路径,还是用默认的路径,希望大佬们能指导一下这里该怎么修改

#模型评估
test_loss,test_accuracy=model.evaluate(test_dataset,steps=math.ceil(10000/32))
print('Accuracy on test_dataset',test_accuracy)

在这里插入图片描述

#模型预测
predictions=model.predict(test_dataset)

#查看预测结果
def plot_test(i,predictions_array,true_labels,images):
    predic,label,img=predictions_array[i],true_labels[i],images[i]
    plt.grid(False)
    plt.xticks([])
    plt.imshow(img[...,0],cmap=plt.cm.binary)
    predic_label=np.argmax(predic)
    if(predic_label==label):
        color='green'
    else:
        color='red'
    plt.xlabel("预测标签为:{},概率:{:2.0f}% (真实标签:{})".format(predic_label,100*np.max(predic),label),color=color)
    
    
def plot_value(i,predictions_array,true_label):
    predic,label=predictions_array[i],true_label[i]
    plt.grid(False)
    plt.xticks([])
    thisplot=plt.bar(range(10),predic,color='#777777')
    plt.ylim([0,1])
    predic_label=np.argmax(predic)
    thisplot[predic_label].set_color('blue')
    thisplot[label].set_color('green')
    
    
rows,cols=5,3
num_images=rows*cols
for test_images,test_labels in test_dataset.take(1):
    test_images=test_images.numpy()
    test_labels=test_labels.numpy()
    
plt.figure(figsize=(2*2*cols,2*rows))
for i in range(num_images):
    plt.subplot(rows,2*cols,2*i+1)
    plot_test(i,predictions,test_labels,test_images)
    plt.subplot(rows,2*cols,2*i+2)
    plot_value(i,predictions,test_labels)

在这里插入图片描述

总结:

总的来说,tensorflow2.0的使用还是比较简单的,没有了之前1.X的session感觉更方便,如果对1.X有的api有不懂的地方可以随时去https://www.w3cschool.cn/tensorflow_python/查找需要的帮助;
对于2.0也可以去https://tf.wiki学习。

  • 14
    点赞
  • 54
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

shelgi

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值