利用CNN对MNIST数据集识别,keras框架

原创 2018年04月16日 18:49:25

1、数据集下载

https://blog.csdn.net/wuzhichenggo/article/details/79332128

2、代码执行

from tensorflow.examples.tutorials.mnist import input_data

#导入数据
dir = r'F:\dataset\mnist'
mnist = input_data.read_data_sets(dir,one_hot=True)

#输出mnist大小
print(mnist.train.images.shape,mnist.train.labels.shape)
print(mnist.test.images.shape,mnist.test.labels.shape)
print(mnist.validation.images.shape,mnist.validation.labels.shape)

#将数据重组变成图像格式输出,图像大小28*28
import numpy as np
train_data = mnist.train.images.reshape(55000,28,28,1)
train_labels = mnist.train.labels
test_data = mnist.test.images.reshape(10000,28,28,1)
test_labels = mnist.test.labels
#validation_data = mnist.validation.images.reshape(5000,28,28,1)
#validation_labels = mnist.validation.labels

#导入keras
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.layers.advanced_activations import PReLU
from keras.layers.convolutional import Convolution2D, MaxPooling2D
from keras.optimizers import SGD, Adadelta, Adagrad
from keras.utils import np_utils, generic_utils
from six.moves import range

#建立CNN
#生成一个model
model = Sequential()

#第一个卷积层,4个卷积核,每个卷积核大小5*5。1表示输入的图片的通道,灰度图为1通道。
#border_mode可以是valid或者full,具体看这里说明:http://deeplearning.net/software/theano/library/tensor/nnet/conv.html#theano.tensor.nnet.conv.conv2d
#激活函数用relu
#你还可以在model.add(Activation('tanh'))后加上dropout的技巧: model.add(Dropout(0.5))
model.add(Convolution2D(4,5,5,input_shape=(28,28,1)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2,2)))

#第二个卷积层,8个卷积核,每个卷积核大小3*3。4表示输入的特征图个数,等于上一层的卷积核个数
#激活函数用relu
#采用maxpooling,poolsize为(2,2)
model.add(Convolution2D(8, 3, 3))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

#第三个卷积层,16个卷积核,每个卷积核大小3*3
#激活函数用tanh
#采用maxpooling,poolsize为(2,2)
model.add(Convolution2D(16, 3, 3))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

#全连接层,先将前一层输出的二维特征图flatten为一维的。
#全连接有128个神经元节点,初始化方式为normal
model.add(Flatten())
model.add(Dense(128))
model.add(Activation('relu'))
model.add(Dropout(0.5))

#最后一层softmax,输出是10个十个类别
model.add(Dense(10))
model.add(Activation('softmax'))

#开始训练模型
#使用SGD+momentum
#model.compile里的参数loss就是损失函数(目标函数)
model.compile(optimizer = 'rmsprop',
              loss='categorical_crossentropy',
              metrics = ['accuracy'])

model.fit(train_data, train_labels,
          nb_epoch=10, batch_size=100,
          validation_data=(test_data, test_labels))

参考:
https://blog.csdn.net/shizhengxin123/article/details/72383728

3、结果显示

这里写图片描述

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/u014135752/article/details/79964634

03-Keras之用MNIST数据集训练一个CNN

03-Keras之用MNIST数据集训练一个CNN模型code# -*- coding: utf-8 -*-'''Trains a simple convnet on the MNIST datase...
  • sinat_25059791
  • sinat_25059791
  • 2017年02月19日 18:31
  • 740

利用keras(tensorflow) 做cnn mnist识别

keras图像数据处理以及图像识别小例子 1、数据预处理 数据集请自行下载,数据不大,20来兆 数据具体如下所示: 格式为 要识别的数字.序号.jpg 数据预处理代码,我用的是tens...
  • shizhengxin123
  • shizhengxin123
  • 2017年05月17日 14:42
  • 4209

DeepLearning&Keras学习笔记3__mnist数据集CNN

1.Introduction利用卷积神经网络CNN对Mnist数据集手写数字进行分类。2.Source code#encoding:utf-8 '''Trains a simple convnet o...
  • Mr_KkTian
  • Mr_KkTian
  • 2017年08月13日 20:25
  • 399

keras/构建卷积神经网络识别mnist

环境:Keras 2.04, python 2.7,GPU使用深度学习框架keras,构建卷积神经网络识别手写数字,keras在构建神经网络方面比Tensorflow简单很多,而且Tensorflow...
  • szj_huhu
  • szj_huhu
  • 2017年07月11日 09:18
  • 671

tensorflow 使用CNN 进行mnist数据集识别

一、CNN的引入 在人工的全连接神经网络中,每相邻两层之间的每个神经元之间都是有边相连的。当输入层的特征维度变得很高时,这时全连接网络需要训练的参数就会增大很多,计算速度就会变得很慢,例如一张黑白的...
  • u011808673
  • u011808673
  • 2017年11月12日 12:32
  • 307

[keras]用CNN来刷Kaggle的digit手写数据集比赛

keras实战keras是比较适合新手的深度框架,不说废话,一切代码全是手敲,有问题可以留言共同交流。1.处理数据从kaggle官方下载数据集,地址:https://www.kaggle.com/c/...
  • neruda1991
  • neruda1991
  • 2017年12月07日 20:46
  • 150

基于MNIST数据集的深度学习库keras的学习

基于MNIST数据集的深度学习库keras的学习目录基于MNIST数据集的深度学习库keras的学习 目录 学习步骤 搭建简单模型训练预测 先上代码如下 训练的结果如下 搭建CNN模型训练预测 还是先...
  • oQiCheng1234567
  • oQiCheng1234567
  • 2017年05月04日 10:36
  • 387

使用libsvm对MNIST数据集进行实验

在学SVM中的实验环节,老师介绍了libsvm的使用。当时看完之后感觉简单的说不出话来。 1. libsvm介绍 虽然原理要求很高的数学知识等,但是libsvm中,完全就是一个工具包,拿来就能用。当时...
  • arthur503
  • arthur503
  • 2014年02月26日 13:36
  • 6955

Keras 深度学习框架Python Example:CNN/mnist

Keras是基于Theano的一个深度学习框架,它的设计参考了Torch,用Python语言编写,是一个高度模块化的神经网络库,支持GPU和CPU。...
  • Eric_Wilson
  • Eric_Wilson
  • 2015年11月15日 15:39
  • 10272

使用Keras搭建一个CNN处理MNIST数据

代码结构这里主要是参考了github上面一个u-net的程序的结构。链接 在项目中,程序员将整个神经网络分成了网络结构和训练两个类,并定义了一些函数来完成类似混淆矩阵生成这样的操作。 在这里,也是...
  • monsterhoho
  • monsterhoho
  • 2017年06月20日 10:04
  • 1054
收藏助手
不良信息举报
您举报文章:利用CNN对MNIST数据集识别,keras框架
举报原因:
原因补充:

(最多只允许输入30个字)