前言
上一章,我们学习了利用逻辑回归算法处理复杂二分类问题的方法,这一章,我们通过手写数字识别的例子来介绍逻辑回归算法对多分类问题的解决办法。
问题描述
我们在写信或寄快递的时候都要填上邮编和手机号来标明邮寄的目的地和收件人,如果我们有一个程序能够自动识别出不同人手写的数字,并将其录入计算机中,无疑会大大增加邮寄的效率。然而,由于每个人的书写习惯不同,要直接找出一个适用于每个人的识别算法殊为不易。这一章我们来介绍通过逻辑回归的方法,让机器自动学习手写字的特征,进而识别其代表的数字。
首先,我们需要有一个训练集。如下图所示,吴老师的课程提供了5000个不同人手写的数字图片,并且标出了每一幅图片对应的正确数字。
这些图片和相应的数字标签就是我们的训练集,每一幅图片的分辨率都是 20 × 20 20×20 20×20的,对应于400个实数。我们将每一幅图片的这400个数字作为输入特征,再加上一个常数项1,共有401个特征。
多分类问题
与二分类问题不同,这里的输出结果是0-9共十个数字,因此处理方法也有所不同。多分类问题的处理方法通常是将其划分为多个二分类问题来处理。在本例中,我们一共有十个类别,因此就要运行十次逻辑回归算法,分别识别出某一幅图片是 0 , 1 , 2 , ⋯   , 9 0,1,2,\cdots,9 0,1,2,⋯,9的概率,取概率最大的那个数字作为最终的识别结果。
具体实现过程,我们下面结合代码来详细讲解。
首先,导入数据,可以看到,数据中有 X X X和 y y y两个变量,其中