内容目录
一、MNIST 介绍1、MNIST 介绍2、获取MNIST数据的几种方法二、模型训练 1、keras 训练 2、感知机训练
3、CNN训练
一、MNIST 介绍
1、MNIST 介绍
MNIST数据集分为训练图像和测试图像。训练图像60000张,测试图像10000张,每一个图片代表0-9中的一个数字,且图片大小均为28*28的矩阵。
train-images-idx3-ubyte.gz: training set images (9912422 bytes) 训练图片
train-labels-idx1-ubyte.gz: training set labels (28881 bytes) 训练标签
t10k-images-idx3-ubyte.gz: test set images (1648877 bytes) 测试图片
t10k-labels-idx1-ubyte.gz: test set labels (4542 bytes) 测试标签
2、获取MNIST数据的几种方法
方法1
官网下载,MNIST数据集的版权在Yann LeCun教授手上,在他的主页下载即可。http://yann.lecun.com/exdb/mnist/下载4个gz文件,实际上这也是旧版TensorFlow中获取mnist的方法。注意,图像数据取值为0到1之间。方法2
谷歌 https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz 下载1个npz文件,实际上这也是新版TensorFlow中获取mnist的方法。注意,图像数据取值为0到255之间。方法3
通过TensorFlow获取,提前下载好放在这里就可以避免无法下载的问题。
#tensorflow 1.7以前
#下载好数据集,放到mnist文件夹下,可以避免无法下载的问题,然后指定datapath来读取。
from tensorflow.examples.tutorials.mnist import input_data
datapath = "./mnist/"
mnist = input_data.read_data_sets(datapath, one_hot=True)
train_x = mnist.train.images
train_y = mnist.train.labels
test_x = mnist.test.images
test_y = mnist.text.labels
#tensorflow 1.7以后
#下载好数据集mnist.npz,放于~/.keras/datasets/下,可以避免无法下载的问题
import tensorflow as tf
(train_x, train_y), (test_x, test_y) = tf.keras.datasets.mnist.load_data(path='mnist.npz')
方法4
通过Keras获取。
from keras.datasets import mnist
(train_x, train_y), (test_x, test_y) = mnist.load_data()
#:\Program Files\Python\Python36-64\Lib\site-packages\keras\datasets
# C:\Users\user_name\.keras\datasets
二、模型训练
1、获数据集
import tensorflow as tf
import time
import matplotlib
import matplotlib.pyplot as plt
start = time.time()
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
#给定的像素的灰度值在0-255,为了使模型的训练效果更好,通常将数值归一化映射到0-1。
print((x_train.shape, y_train.shape), (x_test.shape, y_test.shape))
#((60000, 28, 28), (60000,)) ((10000, 28, 28), (10000,))
2、查看其中图片
some_digit = x_train[3000]
some_digit_image = some_digit.reshape(28, 28)
plt.imshow(some_digit_image, cmap=matplotlib.cm.binary, interpolation="nearest") #灰色的
#plt.imshow(some_digit_image) #彩色的
plt.axis('off')
plt.show()
from keras.datasets import mnist #这里是从keras的datasets中导入mnist数据集
import matplotlib.pyplot as plt #这里是将matplo