前面的博客介绍过神经网络结构以及相关的损失函数,在这里我们通过一个简单的神经网络实现一个机器学习问题:识别手写数字图像。
和求解机器学习问题的步骤(分成学习和推理两个阶段进行)一样,使用神经网络解决问题时,也需要首先使用训练数据(学习数据)进行权重参数的学习;进行推理时,使用刚才学习到的参数,对输入数据进行分类。
1、MNIST数据集
在这里使用常用的且著名的MNIST手写数字图像数据集,下面是关于MNIST数据集的简述:
'http://yann.lecun.com/exdb/mnist/' # MNIST数据网址
'train_img':'train-images-idx3-ubyte.gz', #训练图像
'train_label':'train-labels-idx1-ubyte.gz', #训练图像的标签
'test_img':'t10k-images-idx3-ubyte.gz', #测试图像
'test_label':'t10k-labels-idx1-ubyte.gz' #测试图像的标签
MNIST数据集是由0到9的数字图像构成的(图1)。训练图像有6万张,测试图像有1万张,这些图像可以用于学习和推理。MNIST数据集的一般使用方法是,先用训练图像进行学习,再用学习到的模型度量能在多大程度上对测试图像进行正确的分类。
图1:MNIST图像数据集的例子
MNIST的图像数据是28像素 × 28像素的灰度图像(1通道),各个像素的取值在0到255之间。每个图像数据都相应地标有“7”“2”“1”等标签。下面利用Python脚本显示MNIST数据:
# coding: utf-8
import numpy as np
from PIL import Image # 图像显示模块 PIL (Python Image Library)
def img_show(img):
pil_img = Image.fromarray(np.uint8(img))
pil_img.show()
(x_train, t_train), (x_test, t_test) = load_mnist(flatten=True, normalize=False)
img = x_train[5]
label = t_train[5]
print(label) # 2
print(img.shape) # (784,)
img = img.reshape(28, 28) # 把图像的形状变为原来的尺寸
print(img.shape) # (28, 28)
img_show(img)
------------------------------------------------------------------------------------
def load_mnist(normalize=True, flatten=True, one_hot_label=False):
"""读入MNIST数据集
Parameters
----------
normalize : 将图像的像素值正规化为0.0~1.0
one_hot_label :
one_hot_label为True的情况下,标签作为one-hot数组返回
one-hot数组是指[0,0,1,0,0,0,0,0,0,0]这样的数组
flatten : 是否将图像展开为一维数组
Returns
-------
(训练图像, 训练标签), (测试图像, 测试标签)
"""
if not os.path.exists(save_file):
init_mnist()
with open(save_file, 'rb') as f:
dataset = pickle.load(f)
if normalize:
for key in ('train_img', 'test_img'):
dataset[key] = dataset[key].astype(np.float32)
dataset[key] /= 255.0
if one_hot_label:
dataset['train_label'] = _change_one_hot_label(dataset['train_label'])
dataset['test_label'] = _change_one_hot_label(dataset['test_label'])
if not flatten:
for key in ('train_img', 'test_img'):
dataset[key] = dataset[key].reshape(-1, 1, 28, 28)
return (dataset['train_img'], dataset['train_label']), (dataset['test_img'], dataset['test_label'])
# 结果显示:
2
(784,)
(28, 28)
注:load_mnist 函数以“(训练图像,训练标签),(测试图像,测试标签)”的形式返回读入的MNIST数据。此外,还可以像load_mnist(normalize=True, flatten=True, one_hot_label=False) 这 样,设 置 3 个参数。第1个参数normalize 设置是否将输入图像正规化为0.0~1.0的值。如果将该参数设置为 False ,则输入图像的像素会保持原来的0~255。第2个参数 flatten 设置是否展开输入图像(变成一维数组)。如果将该参数设置为 False ,则输入图像为1 × 28 × 28的三维数组;若设置为 True ,则输入图像会保存为由784个元素构成的一维数组。
以上,就利用MNIST数据集实现的识别手写数字图像了 !!!
参考:深度学习入门:基于Python的理论与实现