最近在了解机器学习方面的相关内容时用Java实现了一个简单的感知机。网上的很多博客在讲解感知机的时候提及了非常多的专业知识,涉及的知识面较广,对小白不是很友好,容易把人给吓跑。因此,今天就想通俗地讲一下感知机的模型,并动手写一个简单的感知机。
一、感知机的基本概念
在机器学习中,感知机是二分类的线性分类模型,属于监督学习算法。输入为实例的特征向量,输出实例的类别(去+1或者-1)。感知机对应于输入空间中将实例划分为两类的分离超平面。感知机旨在求出该超平面,可通过损失函数来求解超平面,并利用梯度下降法对损失函数进行最优化。感知机预测是用学习得到的感知机模型对新的实例进行预测的,因此属于判别模型。
上面的概念听起来有点吓人,我们通过一个简单场景来理解一下其中的各个概念。
例如,在如图的坐标系中只存在两种物品,一种是三角形,一种是五角星。这时如果放入一个未知的图形,只告诉你它的坐标值,要求你对这个图形做出预测,判断这个图形的形状,你该怎么做。由于你只有图形的坐标值,那么首先你就要根据已有的样本数据找出图形和坐标之间的关系。
接着,你可能会发现下面这条特殊的线。它刚好把二维平面中的图形分成了两个部分。三角形都位于直线f=ax+by+c的下方(也就是当你把三角形的坐标代入直线时,得到f<0),而五角星都位于直线f=ax+by+c的上方(也就是当你把五角星的坐标代入直线时,得到f>0)。
最后我们就可以根据我们发现的这条线对后续加入的未知图形进行判断。如果未知图形位于线上,我们就预测它是五角星,位于线下,我们就预测它是三角形。
在这个例子中a、b叫做权值,c叫做偏移量。权值对应事物的特征。比如我们想让机器去分辨大象和老鼠,就可以把体重,脸型,鼻子等等特征作为权值,当权值逐渐增加时平面的维数也会相应增加,机器分类的准确率一般来说也会逐渐提高。当然,简单起见,今天我们只实现二维的感知机。
现在我们重新再来看一下最前面的那段概念。输入实例的特征向量就是输入已知图形的坐标,输出实例的类别就是通过判断图形的坐标位置(在线上还是在线下)来输出实例所属的类别。(可以假定如果是五角星就输出+1,如果是三角形就输出-1)。感知机就是分类的那条线。感知机的创建重点就是要求出那条线。感知机预测就是根据求出的线对未知图形进行预测分类。
二、感知机的创建
通过前面对感知机的了解,我们知道了创建感知机的关键所在就是求出那条分离的线(分离超平面)。对于我们来说可能一眼就可以看出来了(所谓的目测大法)。但是机器要如何求解呢?
在这里,我们可以先给定一条初始的线,然后遍历已有的所有样本图形。判断这条线是否满足已有图形的分类,如果满足,则不对其做任何处理,如果不满足就调整这个线的位置(改变线的参数)重新遍历样本数据。具体如下:
初始化直线,我们随意给出一条线(位置可以低一点)。判断这条线是否满足已有样本的图形分类。代入图形1,它是三角形,应该位于直线下方,正确。代入图形2,它是三角形,应该位于直线下方,然而它却在直线上方,发生错误。这时机器会对直线的位置进行调整使其满足图形2的分类。如右图。(这里就是关键的算法所在——损失函数,到底机器要怎么去移动这条线,才能更快地得出拟合度最高的分离超平面。)。依次遍历完所有的样本数据后,我们就可以得到样本的分离线了。
理想情况如下图,至此感知机的创建就完成了。
三、java代码实现(计算部分的代码借鉴了网上的)
import java.awt.Button;
import java.awt.Color;
import java.awt.Dimension;
import java.awt.FlowLayout;
import java.awt.Graphics;
import java.awt.TextField;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Random;
import javax.swing.JFrame;
public class Feel extends JFrame{
/* 分类器参数
*/
public double[] w;//权值向量组
public double b = 0;//阈值
public double eta = 1;//学习率
ArrayList<Point> arrayList;
/**
* 初始化分类器,传入需要分组的数据
* @param arrayList 需要分类的点
*/
public Feel(ArrayList<Point> arrayList, double eta) {
// 分类器初始化
this.arrayList = arrayList;
w = new double[arrayList.get(0).x.length];//w的数量与x相同
this.eta = eta;
}
public Feel(ArrayList<Point> arrayList) {
// 分类器初始化
this.arrayList = arrayList;
w = new double[arrayList.get(0).x.length];
this.eta = 1;
}
/**
* 进行分类计算
* @return 是否分类成功
*/
public boolean Classify() {
boolean flag = false;
while (!flag) {//一旦出错,需要从头遍历
for (int i = 0; i < arrayList.size(); i++) {//所有训练集
//判断是否出错
if (LearnAnswer(arrayList.get(i)) <= 0) {
UpdateWAndB(arrayList.get(i));
//进行画面的重绘,这里不能用repaint()
this.paint(this.getGraphics());
try {
Thread.sleep(300);
} catch (InterruptedException e) {
e.printStackTrace();
}
break;
}
//如果到了最后一个还没有出错,跳出while循环
if (i == (arrayList.size() - 1)) {
flag = true;
}
}
}
System.out.println("最终结果:");
System.out.println(Arrays.toString(w));
System.out.println(b);
return true;
}
/**
* 进行学习得到的结果
* @param point 需要进行学习的点,训练样本
* @return
*/
private double LearnAnswer(Point point) {
System.out.println(Arrays.toString(w));
System.out.println(b);
return point.y * (DotProduct(w, point.x) + b);
}
/**
* 进行w更新
* @param point 需要根据样本来随机梯度下降来进行w和b更新
* @return 不需要返回值
*/
private void UpdateWAndB(Point point) {
System.out.println("结果出错更新参数");
for (int i = 0; i < w.length; i++) {
w[i] += eta * point.y * point.x[i];
}
b += eta * point.y;
return;
}
/**
* 进行点乘
* @param x1 乘数
* @param x2 乘数
* @return 点乘的积
*/
private double DotProduct(double[] x1, double[] x2) {
int len = x1.length;
double sum = 0;
for (int i = 0; i < len; i++) {
sum += x1[i] * x2[i];
}
return sum;
}
//初始化界面
public void initUI(){
this.setTitle("感知机");
this.setSize(1400,1100);
this.setDefaultCloseOperation(3);
this.setLocationRelativeTo(null);//居中对齐
this.setResizable(false);
//设置布局
FlowLayout layout=new FlowLayout(FlowLayout.CENTER,10,10);
this.setLayout(layout);
//添加按钮
Button buttonStart=new Button("开始训练");
Dimension dim1=new Dimension(80,40);
buttonStart.setPreferredSize(dim1);
this.add(buttonStart);
//添加文本框
TextField textx1=new TextField();
textx1.setPreferredSize(dim1);
TextField textx2=new TextField();
textx2.setPreferredSize(dim1);
this.add(textx1);
this.add(textx2);
//添加按钮监听机制
ButtonListener but=new ButtonListener(this,textx1,textx2);
buttonStart.addActionListener(but);
//添加预测按钮
Button buttonTest=new Button("预测图形颜色");
buttonTest.setPreferredSize(dim1);
this.add(buttonTest);
buttonTest.addActionListener(but);
//界面可视化
this.setVisible(true);
}
//重写重绘函数
public void paint(Graphics g){
super.paint(g);
System.out.println("进入paint函数");
//画出坐标轴
g.setColor(Color.black);
g.drawLine(100, 100, 100, 1100);
g.drawLine(100, 100, 1300, 100);
//System.out.println("repaint");
//画点
for(int i=0;i<arrayList.size();i++){
if(arrayList.get(i).y==1) g.setColor(Color.blue);
else g.setColor(Color.red);
//System.out.println((int)arrayList.get(i).x[0]*200+","+(int)arrayList.get(i).x[1]*200);
g.drawOval((int)arrayList.get(i).x[0]*200+100, (int)arrayList.get(i).x[1]*200+100, 15, 15);
//g.drawOval((int)arrayList.get(i).x[0]*200, (int)arrayList.get(i).x[1]*200, 20, 20);
}
//标准线的位置
//g.setColor(Color.green);
//g.drawLine(0,(int)3.3*200 , 5*200,0 );
//画线,f=w1*x+w2*y+b
int x1=0,y2=0;
int y1=(int) ((-1)*b/w[1]);
int x2=(int) ((-1)*b/w[0]);
g.setColor(Color.black);
//System.out.println(x1*200+","+y1*200+","+x2*200+","+y2*200);
g.drawLine(x1*200+100, y1*200+100, x2*200+100, y2*200+100);
//System.out.println(x1*200+","+y1*200+","+x2*200+","+y2*200);
//g.drawLine(x1*200, y1*200, x2*200, y2*200);
}
/**
* 主程序进行检测
* @param args
*/
public static void main(String[] args) {
//生成样本数据
Point p1 = new Point(new double[] { 0,1.23 }, -1);
Point p2 = new Point(new double[] { 1.32,0 }, -1);
Point p3 = new Point(new double[] { 2.18,1 }, -1);
Point p4 = new Point(new double[] { 1,2.64 }, -1);
Point p5 = new Point(new double[] { 3.5,1.2 }, 1);
Point p6 = new Point(new double[] { 1.3,3.4 }, 1);
Point p7 = new Point(new double[] { 4.3,1 }, 1);
Point p8 = new Point(new double[] { 1,4.4 }, 1);
ArrayList<Point> list = new ArrayList<Point>();
list.add(p1);
list.add(p2);
list.add(p3);
list.add(p4);
list.add(p5);
list.add(p6);
list.add(p7);
list.add(p8);
//实例化对象
Feel classifier = new Feel(list);
classifier.initUI();
//classifier.Classify();
}
}
/**
* 定义一个Point,里面包含两个部分,用来分类。x表示输入R维空间向量,y表示分类值,只有-1和+1两类
*/
class Point {
double[] x = new double[2];
double y = 0;
Point(double[] x, double y) {
this.x = x;
this.y = y;
}
public Point() {
}
}
import java.awt.TextField;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import javax.swing.JOptionPane;
public class ButtonListener implements ActionListener{
public Feel feelUI;
public TextField textx1;
public TextField textx2;
public ButtonListener(Feel feelUI){
this.feelUI=feelUI;
}
public ButtonListener(Feel feelUI,TextField textx1,TextField textx2){
this.feelUI=feelUI;
this.textx1=textx1;
this.textx2=textx2;
}
public void actionPerformed(ActionEvent e) {
// TODO Auto-generated method stub
if(e.getActionCommand().equals("开始训练")){
feelUI.Classify();
}
else if(e.getActionCommand().equals("预测图形颜色")){
//获取待测试样本的坐标信息
String stringx1=textx1.getText();
String stringx2=textx2.getText();
float x1=new Float(stringx1);//强制转化
float x2=new Float(stringx2);
if((x1*feelUI.w[0]+x2*feelUI.w[1]+feelUI.b)>=0){
JOptionPane.showMessageDialog(null,"该图形为蓝色");
}
else JOptionPane.showMessageDialog(null,"该图形为红色");
}
}
}
四、运行结果分析
我们会发现机器一直在调整分离线的参数,也就是分离直线的位置一直在改变。最后我们再拿原有的真实数据去测试当前感知机的分类准确率。
五、总结
1.分离直线不一定只有一条。
2.机器所求得的分离直线不一定或者说很难满足所有的真实样例。机器求得的分离直线可能是下面这么一条线,它会存在一些误差。而机器所求直线与真实分离直线的拟合度跟学习率以及样本的输入数据顺序等都有关。我们可以通过不断地调整学习率和更换样本数据的读入顺序来得到拟合度更高的分离直线。建议做个界面进行参数的设置。
3.数据样本的阈值c和权值w的比例一定不要太大,否则计算时间需要非常长(基本等不出来的那种)。原因就是阈值c和权值w的比例变化速度很慢。为了防止出现c和w比值过大的情况,我把x和y的取值范围设置为0-5,然后再将其扩大200倍投影到画板上。