使用MNIST数据集对0到9之间的数字进行手写数字识别是神经网络的一个典型入门教程。
该技术在现实场景中是很有用的,比如可以把该技术用来扫描银行转帐单或支票,其中帐号和需要转账的金额可以被识别处理并写在明确定义的方框中。
在本教程中,我们将介绍如何使用Julia编程语言和名为Flux的机器学习库来实现这一技术。
为什么使用Flux和Julia?
本教程为什么想使用Flux(https://fluxml.ai/) 和Julia(https://julialang.org/) ,而不是像Torch、PyTorch、Keras或TensorFlow 2.0这样的知名框架呢?
一个很好的原因是因为Flux更易于学习,而且它提供更好的性能和拥有有更大的潜力,另外一个原因是,Flux在仍然是一个小库的情况下实现了很多功能。Flux库非常小,因为它所做的大部分工作都是由Julia编程语言本身提供的。
例如,如果你查看Gorgonia ML库(https://github.com/gorgonia/gorgonia) 中的Go编程语言,你将看到,它明确地展示了其他机器学习库如何构建一个需要执行和区分的表达式图。在Flux中,这个图就是Julia本身。Julia与LISP非常相似,因为Julia代码可以很容易地表示为数据结构,可以对其进行修改和计算。
机器学习概论
如果你是机器学习的新手,你可以跟着本教程来学习,但并不是所有的东西对你来说都是有价值的。你也可以看看我以前关于Medium的一些文章,它们可能会解释你一些新手的疑惑:
线性代数的核心思想。(https://medium.com/@Jernfrost/the-core-idea-of-linear-algebra-7405863d8c1d)
-
线性代数基本上是关于向量和矩阵的,这是你在机器学习中经常用到的东西。
使用引用。(https://medium.com/@Jernfrost/working-with-and-emulating-references-in-julia-e02c1cae5826)
-
它看起来有点不太好理解,但是如果你想理解像Flux这样的ML库,那么理解Julia中的引用是很重要的。
Flux的实现。(https://medium.com/@Jernfrost/implementation-of-a-modern-machine-learning-library-3596badf3be)
-
如何实现Flux-ML库的初学者指南。
机器学习简介。(https://medium.com/@Jernfrost/machine-learning-for-dummies-in-julia-6cd4d2e71a46) 机器学习概论。
简单多层感知机
我们要编程的人工神经网络被称为简单的多层感知机,这是神经网络(ANN)的基础,大多数教科书都会从它开始。
我先展示整个程序,然后我们再更详细地讲解不同的部分。
using Flux, Flux.Data.MNIST, Statistics
using Flux: onehotbatch, onecold, crossentropy, throttle
using Base.Iterators: repeated
# Load training data. 28x28 grayscale images of digits
imgs = MNIST.images()
# Reorder the layout of the data for the ANN
imagestrip(image::Matrix{<:Gray}) = Float32.(reshape(image, :))
X = hcat(imagestrip.(imgs)...)
# Target output. What digit each image represents.
labels = MNIST.labels()
Y = onehotbatch(labels, 0:9)
# Defining the model (a neural network)
m = Chain(
Dense(28*28, 32, relu),
Dense(32, 10),
softmax)
loss(x, y) = crossentropy(m(x), y)
dataset = repeated((X, Y), 200)
opt = ADAM()
evalcb = () -> @show(loss(X, Y))
# Perform training on data
Flux.train!(loss, params(m), dataset, opt, cb = throttle(evalcb, 10))
探索输入数据
数据预处理通常是数据科学中最大的工作之一。通常情况下,数据的组织或格式化方式与将其输入算法所需的方式不同。
我们首先将MNIST数据集加载为60000个28x28像素的灰度图像: