本文主要记录如何使用最简单的人工神经网络求解二分类问题。
先决条件:
- 问题是线性可分的。
- 训练数据集是线性可分的。
基础知识
需要我们分类的样本是什么样子呢?来看看下面这幅最简单的二分类图
事物特征
上图中,x1 与 x2 代表了事物的两个属性,例如在老虎和狮子的分类问题中,x1 和 x2 可以分别代表 “前掌宽度” 和 “体型大小” 两个属性(这里只是举例,实际情况可能和示例图有差别),这些属性可以用来区分老虎和狮子,我们称他们为特征值。如何提取事物特征为特征工程领域问题,这里不做过多阐述。
线性可分
线性可分定义非常简单:对于样本集,存在一条直线、一个平面或一个超平面可以将两类事物分割开来,则称线性可分。如上图所示,class1 与 class2 可以使用一条直线分割开来。
分割面公式
对于一条直线,我们需要一个点来分割两类;对于二维空间,我们需要一条直线来分割;对于三维空间,我们需要一个二维平面来分割;对于四维空间,我们需要一个三维的超平面来分割;依此类推,更高维度的空间,都需要找一个比它低一维的超平面才能进行分割,线性可分的分类问题即是如此,我们需要找到一个超平面来将不同的分类分隔开,那么这个平面的方程为:Ax+By+Cz+D=0(平面方程的一般式),其中A、B、C、D均为常量且不同时为0,类比可得,更高维的超平面将是w1x+w2x+…+wnx+b=0。
对于上图问题,ax1+bx2+c=0 则是将 class1 与 class2 分割开来的一个超平面。
简单人工神经网络(ANN)模型
公式
这个公式其实就是分割面公式。
输入 X
上图则为一个最简的ANN模型,其中 x1~x3 为输入 ,这些输入分别代表了问题的某个特征,例如,在狮子和老虎的分类问题中,x1…xn 可以分别代表体毛程度、体重、纹路、奔跑速度、咬合力…,这些变量我们统一定义为 X 并且都为数值,因为只有有了相应的数值,才能量化特征。一般来说,这些特性都是足够用于区别不同分类,若某个特征在这个分类问题中散乱分布,例如站立时的形状轮廓
对于上两幅轮廓图,我们可能并不能找到合适的方法来比较准确的区分狮子和老虎,这种特征则为噪声。
权重 W 与偏置 b
其中的 w1~wn 为权重变量,统一定义为 W,可以简单理解为其对应特征的影响程度的描述,就如同我们判断老虎和狮子时通常是通过对比体毛和纹路来得到答案,我们更关注与这两个特征,因此我们大概可以判定,W 在这两个特征上的绝对值将会偏大。
其中的 b 为偏置,它是为了更好的拟合数据,有兴趣的可自行百度。
输出 y
y 为神经网络最终得到的结果。在二分类问题中,我们通常将 y 定义为:当 y>=0时,y=1,否则 y=-1。理解起来并不困难,因为在二分类问题中,目标只可能为两个类别中的某一类,为了方便观察,以 0 为界限,最终得到正数则为类别 1,得到负数则为类别 2。
训练样本
我们学习知识通常需要一个目标以及一些已知的知识储备,机器学习一样如此,需要输入训练样本,机器才能知道需要学习的东西。我们将训练样本定义为(X,Y),其中 X 为多个特征值(x1~xn),Y 则为这个样本的类别(或标签),在二分类问题中,Y 一般为 1 或 -1 来表示不同类别。
需要学习的参数
模型中,X 需要我们手动输入,因此 X 是为已知的,y 为最终计算得到的结果,因此 y 也并不是需要学习的,权重 W 与偏置 b 没有从任何地方可以得到,因此它们就是需要学习的参数。
学习方法
现在我们知道了模型的计算过程、结果定义以及需要学习的参数,那么通过什么手段来学习 W 和 b 呢?这里选择使用 感知器算法 对模型中的 W、b 进行更新。它是一种最简单的学习算法,是一种二元线性分类器,具体公式为
学习步骤:
- 在训练样本集中选择一个样本(X,Y),将这个样本输入神经网络模型
- 当样本类别 Y=1,且计算结果 y<=0 时,更新权重 W 与 b;当样本类别 Y=-1,且计算结果 y>=0 时,更新权重 W 与 b
- 重复步骤 1,2,直到对于训练样本集中所有的样本都不满足上述公式,则为模型训练完成
模型测试
准备一组测试样本集,测试样本与训练样本格式一致。将测试样本的 X 输入到训练好的模型中,模型计算得到 y,如 y 与 样本类别 Y 符号一致(同为正数或同为负数),则分类成功,反之则为失败,成功率越高则说明模型训练的越好。当然对于相同的训练样本集,反复的训练最终得到的模型都是一样的。
Java实现
模型与学习算法都确定好了之后便可以进行编码实现了
- 首先定义样本数据类型 Data.java
public class Data {
// 特征值
private double[] x;
// 标签
private int y;
public Data(double[] x, int y) {
this.x = x;
if (y >= 0) {
this.y = 1;
} else {
this.y = -1;
}
}
publi