MXNet官方文档教程(2):基于卷积神经网络的手写数字识别示例

原本打算开始翻译计算图的部分,结果上一篇刚发完,MXNet就升级了教程文档(伤不起啊),更新了上一篇中手写数字识别示例的详细教程。那这一篇就与时俱进,来将刚更新的这篇教程翻译过来把。由于目前图片无法上传到博客中,相关图片可在原网站查看:Handwritten Digit Recognition



本教程引导你完成一个有关计算机视觉分类的应用示例:使用人工神经网络识别手写数字

 

加载数据

我们首先需要获取MNIST 数据,该数据集是手写数字识别常用的数据集。数据集中的每一幅图像都被缩放为28*28像素大小的灰度图(灰度值介于0254之间)。以下代码下载并加载图像和与图像对应的标签到numpy

import numpy as np

import os

impor turllib

import gzip

import struct

def download_data(url, force_download=True):

    fname = url.split("/")[-1]

    if force_downloadornot os.path.exists(fname):

        urllib.urlretrieve(url,fname)

    return fname

 

def read_data(label_url, image_url):

    with gzip.open(download_data(label_url))as flbl:

        magic, num = struct.unpack(">II", flbl.read(8))

        label = np.fromstring(flbl.read(), dtype=np.int8)

    with gzip.open(download_data(image_url),'rb')as fimg:

        magic, num, rows, cols = struct.unpack(">IIII", fimg.read(16))

        image = np.fromstring(fimg.read(), dtype=np.uint8).reshape(len(label), rows, cols)

    return (label, image)

 

path='http://yann.lecun.com/exdb/mnist/'

(train_lbl, train_img)= read_data(

    path+'train-labels-idx1-ubyte.gz', path+'train-images-idx3-ubyte.gz')

(val_lbl, val_img) = read_data(

    path+'t10k-labels-idx1-ubyte.gz', path+'t10k-images-idx3-ubyte.gz')

我们输出了前10幅图像和他们所对应的标签:

%matplotlib inline
import matplotlib.pyplot as plt
for i inrange(10):
    plt.subplot(1,10,i+1)
    plt.imshow(train_img[i], cmap='Greys_r')
    plt.axis('off')
plt.show()
print('label: %s'% (train_lbl[0:10],))

label: [5 0 4 1 9 2 1 3 1 4]

之后我们为MXNet创建数据迭代器。和迭代器一样,数据迭代器在每次调用next()函数时返回一批数据,包括多幅图片和其对应的标签。这些图像保存在一个大小为(batch_size, num_channels, width, height)的4维矩阵中。对于MNIST数据集来说,图像仅有一个色彩通道且高宽均为28。此外,我们经常洗乱用于训练的图像,以加快训练的速度。

import mxnet as mx
 
defto4d(img):
    return img.reshape(img.shape[0],1,28,28).astype(np.float32)/255
 
batch_size=100
train_iter= mx.io.NDArrayIter(to4d(train_img), train_lbl, batch_size, shuffle=True)
val_iter= mx.io.NDArrayIter(to4d(val_img), val_lbl, batch_size)

多层感知机

一个多层感知机包含多个全连接层。对于全连接层来说,假设输入矩阵X的大小为n*m,输出矩阵Y的大小为n*k,其中k通常被称为隐藏大小。这个层有两个参数,m*n的权重矩阵W和m*1的偏移向量b。则输出由下式得出:

Y =WX + b

全连接层的输出通常输入到一个卷积层,进行逐像素操作(elemental-wise operations)。其中一个很著名的函数就是Sigmoid函数:f(x)= 1/(1+e^(-x))。而如今人们也使用一个更简单的叫做relu函数:f(x) = max(0,x)。

最后一个全连接层通常拥有和数据集中的类别个数一样的隐藏大小。最后我们压入一个softmax层,它可以将输入映射到表示可能性的分值。同样假设输入X大小为n*m,x_i为第i行。则第i行的输出为:

定义多层感知机在MXNet中是很简单的,如下所示。

# Create a place holder variable for the input data
data= mx.sym.Variable('data')
# Flatten the data from 4-D shape (batch_size, num_channel, width, height) 
# into 2-D (batch_size, num_channel*width*height)
data= mx.sym.Flatten(data=data)
 
# The first fully-connected layer
fc1  = mx.sym.FullyConnected(data=data, name='fc1', num_hidden=128)
# Apply relu to the output of the first fully-connnected layer
act1= mx.sym.Activation(data=fc1, name='relu1', act_type="relu")
 
# The second fully-connected layer and the according activation function
fc2  = mx.sym.FullyConnected(data=act1, name='fc2', num_hidden =64)
act2= mx.sym.Activation(data=fc2, name='relu2', act_type="relu")
 
# The thrid fully-connected layer, note that the hidden size should be 10, which is the number of unique digits
fc3  = mx.sym.FullyConnected(data=act2, name='fc3', num_hidden=10)
# The softmax and loss layer
mlp  = mx.sym.SoftmaxOutput(data=fc3, name='softmax')
 
# We visualize the network structure with output size (the batch_size is ignored.)
shape= {"data" : (batch_size, 1,28,28)}
mx.viz.plot_network(symbol=mlp, shape=shape)

现在神经网络定义和数据迭代器都已经准备好了。我们可以开始训练了:

import logging
logging.getLogger().setLevel(logging.DEBUG)
 
model= mx.model.FeedForward(
    symbol = mlp,       # network structure
    num_epoch =10,     # number of data passes for training 
    learning_rate =0.1# learning rate of SGD 
)
model.fit(
    X=train_iter,       # training data
    eval_data=val_iter,# validation data
    batch_end_callback = mx.callback.Speedometer(batch_size,200)# output progress for each 200 data batches
)
INFO:root:Start training with [cpu(0)]
INFO:root:Epoch[0] Batch [200]  Speed: 26279.17 samples/sec Train-accuracy=0.111550
INFO:root:Epoch[0] Batch [400]  Speed: 27424.98 samples/sec Train-accuracy=0.111000
INFO:root:Epoch[0] Batch [600]  Speed: 27094.87 samples/sec Train-accuracy=0.133200
INFO:root:Epoch[0] Resetting Data Iterator
INFO:root:Epoch[0] Time cost=2.320
INFO:root:Epoch[0] Validation-accuracy=0.276800
INFO:root:Epoch[1] Batch [200]  Speed: 17739.48 samples/sec Train-accuracy=0.412650
INFO:root:Epoch[1] Batch [400]  Speed: 18869.69 samples/sec Train-accuracy=0.753500
INFO:root:Epoch[1] Batch [600]  Speed: 25618.04 samples/sec Train-accuracy=0.828750
INFO:root:Epoch[1] Resetting Data Iterator
INFO:root:Epoch[1] Time cost=2.988
INFO:root:Epoch[1] Validation-accuracy=0.854400
INFO:root:Epoch[2] Batch [200]  Speed: 21532.09 samples/sec Train-accuracy=0.859750
INFO:root:Epoch[2] Batch [400]  Speed: 27919.08 samples/sec Train-accuracy=0.888700
INFO:root:Epoch[2] Batch [600]  Speed: 26810.95 samples/sec Train-accuracy=0.905550
INFO:root:Epoch[2] Resetting Data Iterator
INFO:root:Epoch[2] Time cost=2.408
INFO:root:Epoch[2] Validation-accuracy=0.916300
INFO:root:Epoch[3] Batch [200]  Speed: 28097.98 samples/sec Train-accuracy=0.917300
INFO:root:Epoch[3] Batch [400]  Speed: 27490.20 samples/sec Train-accuracy=0.925850
INFO:root:Epoch[3] Batch [600]  Speed: 27937.45 samples/sec Train-accuracy=0.934900
INFO:root:Epoch[3] Resetting Data Iterator
INFO:root:Epoch[3] Time cost=2.167
INFO:root:Epoch[3] Validation-accuracy=0.938400
INFO:root:Epoch[4] Batch [200]  Speed: 26948.04 samples/sec Train-accuracy=0.942450
INFO:root:Epoch[4] Batch [400]  Speed: 24250.66 samples/sec Train-accuracy=0.943200
INFO:root:Epoch[4] Batch [600]  Speed: 22772.67 samples/sec Train-accuracy=0.951550
INFO:root:Epoch[4] Resetting Data Iterator
INFO:root:Epoch[4] Time cost=2.456
INFO:root:Epoch[4] Validation-accuracy=0.951500
INFO:root:Epoch[5] Batch [200]  Speed: 27313.59 samples/sec Train-accuracy=0.955500
INFO:root:Epoch[5] Batch [400]  Speed: 28061.48 samples/sec Train-accuracy=0.955100
INFO:root:Epoch[5] Batch [600]  Speed: 26730.32 samples/sec Train-accuracy=0.960500
INFO:root:Epoch[5] Resetting Data Iterator
INFO:root:Epoch[5] Time cost=2.206
INFO:root:Epoch[5] Validation-accuracy=0.956300
INFO:root:Epoch[6] Batch [200]  Speed: 28440.23 samples/sec Train-accuracy=0.962700
INFO:root:Epoch[6] Batch [400]  Speed: 28832.82 samples/sec Train-accuracy=0.962700
INFO:root:Epoch[6] Batch [600]  Speed: 27814.78 samples/sec Train-accuracy=0.967150
INFO:root:Epoch[6] Resetting Data Iterator
INFO:root:Epoch[6] Time cost=2.131
INFO:root:Epoch[6] Validation-accuracy=0.960300
INFO:root:Epoch[7] Batch [200]  Speed: 20942.23 samples/sec Train-accuracy=0.967550
INFO:root:Epoch[7] Batch [400]  Speed: 22264.85 samples/sec Train-accuracy=0.967750
INFO:root:Epoch[7] Batch [600]  Speed: 21294.69 samples/sec Train-accuracy=0.971500
INFO:root:Epoch[7] Resetting Data Iterator
INFO:root:Epoch[7] Time cost=2.805
INFO:root:Epoch[7] Validation-accuracy=0.961400
INFO:root:Epoch[8] Batch [200]  Speed: 17870.55 samples/sec Train-accuracy=0.972550
INFO:root:Epoch[8] Batch [400]  Speed: 11526.75 samples/sec Train-accuracy=0.971600
INFO:root:Epoch[8] Batch [600]  Speed: 15082.47 samples/sec Train-accuracy=0.974500
INFO:root:Epoch[8] Resetting Data Iterator
INFO:root:Epoch[8] Time cost=4.197
INFO:root:Epoch[8] Validation-accuracy=0.963000
INFO:root:Epoch[9] Batch [200]  Speed: 10139.52 samples/sec Train-accuracy=0.976000
INFO:root:Epoch[9] Batch [400]  Speed: 10321.69 samples/sec Train-accuracy=0.975550
INFO:root:Epoch[9] Batch [600]  Speed: 10820.23 samples/sec Train-accuracy=0.977750
INFO:root:Epoch[9] Resetting Data Iterator
INFO:root:Epoch[9] Time cost=5.777
INFO:root:Epoch[9] Validation-accuracy=0.964100

完成训练后,我们对单幅图片进行测试。

plt.imshow(val_img[0], cmap='Greys_r')
plt.axis('off')
plt.show()
prob= model.predict(val_img[0:1].astype(np.float32)/255)[0]
print'Classified as %d with probability %f'% (prob.argmax(),max(prob))

Classified as 7 with probability 0.999781

我们也可以通过给予一个数据迭代器来计算正确率。

print'Validation accuracy: %f%%'% (model.score(val_iter)*100,)

 

Validation accuracy: 96.410000%

甚至,我们可以识别写在框中的数字。

from IPython.display import HTML
import cv2
import numpy as np
from mnist_demo import html, script
def classify(img):
    img = img[len('data:image/png;base64,'):].decode('base64')
    img = cv2.imdecode(np.fromstring(img, np.uint8),-1)
    img = cv2.resize(img[:,:,3], (28,28))
    img = img.astype(np.float32).reshape((1,1,28,28))/255.0
    return model.predict(img)[0].argmax()
 
'''
To see the model in action, run the demo notebook at
https://github.com/dmlc/mxnet-notebooks/blob/master/python/tutorials/mnist.ipynb.
'''
HTML(html+ script)

卷积神经网络

注意之前的全连接层在训练时只是将图像转换为向量,而忽略了像素在水平和垂直维度上的空间信息。卷积层的作用就是通过使用一个更结构化的权重W来克服这一缺点。它使用2维卷积来代替简单的矩阵乘法来得到输出。

我们也可以使用多个特征图(每一个都拥有一个不同的权重矩阵)来提取不同的特征。

 

除了卷积层外,另一个卷积神经网络主要的变化就是加入了池化层(pooling layers)。池化层将一个n*m(通常我们称其为核大小)的图像转化为一个单独的值来降低人工神经网络对于空间位置的敏感程度(译者注:为了避免过拟合。)

data= mx.symbol.Variable('data')
# first conv layer
conv1= mx.sym.Convolution(data=data, kernel=(5,5), num_filter=20)
tanh1= mx.sym.Activation(data=conv1, act_type="tanh")
pool1= mx.sym.Pooling(data=tanh1, pool_type="max", kernel=(2,2), stride=(2,2))
# second conv layer
conv2= mx.sym.Convolution(data=pool1, kernel=(5,5), num_filter=50)
tanh2= mx.sym.Activation(data=conv2, act_type="tanh")
pool2= mx.sym.Pooling(data=tanh2, pool_type="max", kernel=(2,2), stride=(2,2))
# first fullc layer
flatten= mx.sym.Flatten(data=pool2)
fc1= mx.symbol.FullyConnected(data=flatten, num_hidden=500)
tanh3= mx.sym.Activation(data=fc1, act_type="tanh")
# second fullc
fc2= mx.sym.FullyConnected(data=tanh3, num_hidden=10)
# softmax loss
lenet= mx.sym.SoftmaxOutput(data=fc2, name='softmax')

注意上面的LeNet模型比多层感知机更加复杂,所以我们使用GPU代替CPU来进行训练。

model= mx.model.FeedForward(
    ctx = mx.gpu(0),     # use GPU 0 for training, others are same as before
    symbol = lenet,       
    num_epoch =10,     
    learning_rate =0.1)
model.fit(
    X=train_iter,  
    eval_data=val_iter,
    batch_end_callback = mx.callback.Speedometer(batch_size,200)
)

 

INFO:root:Start training with [gpu(0)]
INFO:root:Epoch[0] Batch [200]  Speed: 14804.86 samples/sec Train-accuracy=0.111500
INFO:root:Epoch[0] Batch [400]  Speed: 14294.26 samples/sec Train-accuracy=0.111000
INFO:root:Epoch[0] Batch [600]  Speed: 14273.05 samples/sec Train-accuracy=0.113600
INFO:root:Epoch[0] Resetting Data Iterator
INFO:root:Epoch[0] Time cost=4.446
INFO:root:Epoch[0] Validation-accuracy=0.113500
INFO:root:Epoch[1] Batch [200]  Speed: 14332.64 samples/sec Train-accuracy=0.141350
INFO:root:Epoch[1] Batch [400]  Speed: 14785.42 samples/sec Train-accuracy=0.777650
INFO:root:Epoch[1] Batch [600]  Speed: 14796.36 samples/sec Train-accuracy=0.914550
INFO:root:Epoch[1] Resetting Data Iterator
INFO:root:Epoch[1] Time cost=4.105
INFO:root:Epoch[1] Validation-accuracy=0.937700
INFO:root:Epoch[2] Batch [200]  Speed: 14877.08 samples/sec Train-accuracy=0.941850
INFO:root:Epoch[2] Batch [400]  Speed: 14806.53 samples/sec Train-accuracy=0.955900
INFO:root:Epoch[2] Batch [600]  Speed: 14844.79 samples/sec Train-accuracy=0.965200
INFO:root:Epoch[2] Resetting Data Iterator
INFO:root:Epoch[2] Time cost=4.048
INFO:root:Epoch[2] Validation-accuracy=0.971200
INFO:root:Epoch[3] Batch [200]  Speed: 14873.95 samples/sec Train-accuracy=0.971150
INFO:root:Epoch[3] Batch [400]  Speed: 14793.99 samples/sec Train-accuracy=0.972400
INFO:root:Epoch[3] Batch [600]  Speed: 14806.52 samples/sec Train-accuracy=0.976600
INFO:root:Epoch[3] Resetting Data Iterator
INFO:root:Epoch[3] Time cost=4.052
INFO:root:Epoch[3] Validation-accuracy=0.980600
INFO:root:Epoch[4] Batch [200]  Speed: 14428.12 samples/sec Train-accuracy=0.979100
INFO:root:Epoch[4] Batch [400]  Speed: 14298.85 samples/sec Train-accuracy=0.979550
INFO:root:Epoch[4] Batch [600]  Speed: 14618.55 samples/sec Train-accuracy=0.982400
INFO:root:Epoch[4] Resetting Data Iterator
INFO:root:Epoch[4] Time cost=4.158
INFO:root:Epoch[4] Validation-accuracy=0.983300
INFO:root:Epoch[5] Batch [200]  Speed: 14919.47 samples/sec Train-accuracy=0.983700
INFO:root:Epoch[5] Batch [400]  Speed: 14809.71 samples/sec Train-accuracy=0.984050
INFO:root:Epoch[5] Batch [600]  Speed: 14550.25 samples/sec Train-accuracy=0.986250
INFO:root:Epoch[5] Resetting Data Iterator
INFO:root:Epoch[5] Time cost=4.071
INFO:root:Epoch[5] Validation-accuracy=0.985100
INFO:root:Epoch[6] Batch [200]  Speed: 14363.59 samples/sec Train-accuracy=0.986500
INFO:root:Epoch[6] Batch [400]  Speed: 14629.87 samples/sec Train-accuracy=0.986950
INFO:root:Epoch[6] Batch [600]  Speed: 14842.83 samples/sec Train-accuracy=0.988700
INFO:root:Epoch[6] Resetting Data Iterator
INFO:root:Epoch[6] Time cost=4.113
INFO:root:Epoch[6] Validation-accuracy=0.985300
INFO:root:Epoch[7] Batch [200]  Speed: 14863.48 samples/sec Train-accuracy=0.988950
INFO:root:Epoch[7] Batch [400]  Speed: 14824.65 samples/sec Train-accuracy=0.988800
INFO:root:Epoch[7] Batch [600]  Speed: 14278.57 samples/sec Train-accuracy=0.990350
INFO:root:Epoch[7] Resetting Data Iterator
INFO:root:Epoch[7] Time cost=4.102
INFO:root:Epoch[7] Validation-accuracy=0.986400
INFO:root:Epoch[8] Batch [200]  Speed: 14875.69 samples/sec Train-accuracy=0.990300
INFO:root:Epoch[8] Batch [400]  Speed: 14833.44 samples/sec Train-accuracy=0.990750
INFO:root:Epoch[8] Batch [600]  Speed: 14804.53 samples/sec Train-accuracy=0.992250
INFO:root:Epoch[8] Resetting Data Iterator
INFO:root:Epoch[8] Time cost=4.049
INFO:root:Epoch[8] Validation-accuracy=0.987200
INFO:root:Epoch[9] Batch [200]  Speed: 14864.23 samples/sec Train-accuracy=0.992000
INFO:root:Epoch[9] Batch [400]  Speed: 14699.46 samples/sec Train-accuracy=0.991650
INFO:root:Epoch[9] Batch [600]  Speed: 14853.07 samples/sec Train-accuracy=0.992800
INFO:root:Epoch[9] Resetting Data Iterator
INFO:root:Epoch[9] Time cost=4.058
INFO:root:Epoch[9] Validation-accuracy=0.987800

注意到对于同样超参数,LeNet模型达到了98.7%的精度,高于多层感知机的96.6%

  • 2
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值