(Tensorflow学习)MNIST手写体识别 Keras实现

用原生的tensorflow搭建网络有点繁琐,这边用keras就方便了很多

实现一个全连接网络 隐藏层节点500个

from keras.models import Sequential
from keras.layers import Dense,Activation
from keras.losses import categorical_crossentropy
from keras.optimizers import Adadelta
from tensorflow.examples.tutorials.mnist import input_data

model=Sequential()
model.add(Dense(500,activation='relu',input_shape=[784]))
model.add(Dense(10,activation='softmax'))
model.compile(loss=categorical_crossentropy,optimizer=Adadelta(),metrics=['accuracy'])

batch_size=100
epochs=8

mnist=input_data.read_data_sets("D:/mnist",one_hot=True)
train_X=mnist.train.images
train_Y=mnist.train.labels
model.fit(train_X,train_Y,batch_size=batch_size,epochs=epochs)

test_X=mnist.test.images
test_Y=mnist.test.labels
loss,accuracy=model.evaluate(test_X,test_Y,verbose=1)
print('loss:%.4f accuracy:%.4f'%(loss,accuracy))

 

 

实现CNN

from keras.models import Sequential
from keras.layers import Conv2D, MaxPool2D, Flatten, Dropout, Dense
from keras.losses import categorical_crossentropy
from keras.optimizers import Adadelta
from tensorflow.examples.tutorials.mnist import input_data

model=Sequential()
model.add(Conv2D(32,(5,5),activation='relu',input_shape=[28,28,1]))
model.add(MaxPool2D(pool_size=(2,2),strides=[2,2],padding='SAME'))
model.add(Conv2D(64,(5,5),activation='relu'))
model.add(MaxPool2D(pool_size=(2,2),strides=[1,1],padding='SAME'))
model.add(Flatten())
model.add(Dropout(0.5))
model.add(Dense(128,activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(10,activation='softmax'))
model.compile(loss=categorical_crossentropy,optimizer=Adadelta(),metrics=['accuracy'])

batch_size=100
epochs=8

mnist=input_data.read_data_sets("D:/mnist",one_hot=True)
train_X=mnist.train.images
train_Y=mnist.train.labels
train_X_trans=train_X.reshape(-1,28,28,1)
model.fit(train_X_trans,train_Y,batch_size=batch_size,epochs=epochs)

test_X=mnist.test.images
test_Y=mnist.test.labels
test_X_trans=test_X.reshape(-1,28,28,1)
loss, accuracy = model.evaluate(test_X_trans, test_Y, verbose=1)
print('loss:%.4f accuracy:%.4f' %(loss, accuracy))

keras方便很多

OK

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值