tensorflow2.0 手写数字识别demo

tensorflow2.0 手写数字识别demo

tensorflow2.0 使用keras高级API,可以摆脱繁琐的tf.nn了。

tf.keras 用法和 keras基本相同。

通过手写数字识别CNN demo,熟悉tf.keras的基本用法。

demo 代码

#%%
import tensorflow as tf
import time

#%%
#自动加载mnist dataset
mnist = tf.keras.datasets.mnist 

#%%
#training set 60000 samples, test set 10000 samples
#labels格式0-9数字
(x_train, y_train), (x_test, y_test) = mnist.load_data() 

#%%
#normalization
x_train, x_test = x_train/255.0, x_test/255.0

#%%
print(y_train[0:20])
#%%
#building model
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Reshape((28, 28, 1))) # 1 是 channels,model.fit()需要channels
model.add(tf.keras.layers.Conv2D(64, (5,5), activation=tf.keras.activations.relu))
model.add(tf.keras.layers.MaxPool2D(2,2))
model.add(tf.keras.layers.Conv2D(128, (3,3),activation=tf.keras.activations.relu))
model.add(tf.keras.layers.AveragePooling2D((2,2)))
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(512, activation=tf.keras.activations.relu))
model.add(tf.keras.layers.Dense(10, activation=tf.keras.activations.softmax))

model.compile(
    optimizer='SGD',
    loss='sparse_categorical_crossentropy', 
    #如果想使用categorical_crossentropy, 需要把labels转成onehot格式
    metrics=['accuracy']
)


#%%
#training
now = time.time()
model.fit(x_train, y_train, epochs=20) #跑20次就结束吧
#testing
model.evaluate(x_test, y_test)
print(time.time()-now)


运行结果

训练集 准确率99.00%
测试集 准确率99.03%
使用2080 maxq 训练的,每个epoch耗时5s左右。
没有过拟合,epochs设置更大或许还能继续提高准确率,不过只是为了体验一下tensorflow2.0,无所谓了。

  • 0
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值