机器学习之感知机

    本文主要记录本人在学习机器学习过程中的相关代码实现:用 Java Swing 实现感知机原始形式与对偶形式的训练及预测,参考《机器学习实战》。

package perceptron;

import java.awt.BorderLayout;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;

import javax.swing.*;

/**
 * Swing window for training a perceptron on user-entered data and classifying new points.
 *
 * <p>The UI walks through a fixed sequence: feature dimension, number of training samples, each
 * sample (comma-separated x components plus its label y), learning rate, then training in either
 * the primal form ("原始形式") or the dual form ("对偶形式"), and finally prediction. Every
 * parameter update is appended to the text area as one row of an iteration table.
 *
 * <p>NOTE(review): the update rules multiply by y, so labels are presumably +1 / -1; this is not
 * enforced.
 */
public class PerceptronFrame extends JFrame {

	/**
	 * Maximum number of parameter updates per training run. Training runs on the event dispatch
	 * thread and the perceptron never converges on non-separable data, so without a cap the UI
	 * would hang forever.
	 */
	private static final int MAX_ITERATIONS = 10000;
	private static final int DEFAULT_WIDTH = 600;
	private static final int DEFAULT_HEIGHT = 400;

	protected int dim = 1;// feature-space dimension
	protected int num = 1;// number of training samples
	double[][] xData;// training inputs, xData[sample][feature]
	double[] yData;// training labels, one per sample
	int i_n = 0;// index of the next sample to be entered with the "添加" button
	int i_d = 0;// kept for compatibility; no longer used as a shared loop counter
	protected double learn = 0;// learning rate
	double[] w;// weight vector (primal training, or derived from alpha after dual training)
	double[] alpha;// dual coefficients, one per training sample
	double b;// bias
	double[][] gram;// Gram matrix, gram[i][j] = x_i . x_j (dual form)
	private JPanel panel1;
	private InputPanel panel2;
	private ButtonPanel panel3;
	private JTextArea textArea;
	private JTextField field2;
	public JTextField fieldX;
	public JTextField fieldY;
	public JButton addButton;
	public JTextField field3;
	public JButton original;
	public JButton dual;
	public JButton confirmButton;
	public JTextField field1;
	public JTextField inputField;
	public JButton predictButton;

	/** Builds the window: log area on top, data-entry panel in the middle, actions at the bottom. */
	public PerceptronFrame() {
		resetModel();

		setTitle("感知机学习");
		setSize(DEFAULT_WIDTH, DEFAULT_HEIGHT);
		panel1 = new JPanel();
		textArea = new JTextArea("请点击下面的按钮进行操作", 10, 52);
		textArea.setEditable(false);
		textArea.setLineWrap(true);
		JScrollPane scrollPane = new JScrollPane(textArea);
		panel1.add(scrollPane);
		add(panel1, BorderLayout.NORTH);

		panel2 = new InputPanel();
		add(panel2, BorderLayout.CENTER);

		panel3 = new ButtonPanel();
		add(panel3, BorderLayout.SOUTH);
	}

	/** Restores every model/training field to its initial value (shared by constructor and reset). */
	private void resetModel() {
		dim = 1;
		num = 1;
		xData = null;
		yData = null;
		i_n = 0;
		i_d = 0;
		learn = 0;
		w = null;
		alpha = null;
		b = 0;
		gram = null;
	}

	/** Panel for entering the dimension, the sample count, the samples and the learning rate. */
	public class InputPanel extends JPanel {

		public InputPanel() {
			setSize(550, 100);
			JLabel labelx = new JLabel("x:");
			fieldX = new JTextField(10);
			JLabel labely = new JLabel("y:");
			fieldY = new JTextField(10);
			JLabel label1 = new JLabel("输入空间特征维数:");
			field1 = new JTextField(10);
			JLabel label2 = new JLabel("输入训练数据组数:");
			field2 = new JTextField(10);
			field2.setEditable(false);

			setLayout(null);
			add(labelx);
			labelx.setBounds(10, 10, 30, 20);
			add(fieldX);
			fieldX.setBounds(30, 10, 100, 20);
			fieldX.setEditable(false);
			add(labely);
			labely.setBounds(10, 40, 30, 20);
			add(fieldY);
			fieldY.setBounds(30, 40, 100, 20);
			fieldY.setEditable(false);
			add(label1);
			label1.setBounds(200, 10, 150, 20);
			add(field1);
			field1.setBounds(320, 10, 100, 20);
			add(label2);
			label2.setBounds(200, 40, 150, 20);
			add(field2);
			field2.setBounds(320, 40, 100, 20);

			addButton = new JButton("添加");
			addButton.setEnabled(false);
			// NOTE(review): no action listener is attached to this button anywhere in this class.
			confirmButton = new JButton("确定");
			confirmButton.setEnabled(false);
			JLabel label3 = new JLabel("输入学习率:");
			field3 = new JTextField();
			field3.setEditable(false);
			original = new JButton("原始形式");
			original.setEnabled(false);
			dual = new JButton("对偶形式");
			dual.setEnabled(false);

			// Step 1: feature dimension.
			field1.addActionListener(new ActionListener() {

				@Override
				public void actionPerformed(ActionEvent e) {
					int value = PerceptronMath.parsePositiveInt(field1.getText());
					if (value < 0) {
						textArea.append("\n空间特征维数必须是正整数");
						return;
					}
					dim = value;
					textArea.append("\n空间特征维数为:" + dim);
					field1.setEditable(false);
					field2.setEditable(true);
				}
			});

			// Step 2: number of training samples; allocates the training arrays.
			field2.addActionListener(new ActionListener() {

				@Override
				public void actionPerformed(ActionEvent e) {
					int value = PerceptronMath.parsePositiveInt(field2.getText());
					if (value < 0) {
						textArea.append("\n训练数据个数必须是正整数");
						return;
					}
					num = value;
					textArea.append("\n训练数据个数为:" + num);
					field2.setEditable(false);
					fieldX.setEditable(true);
					fieldY.setEditable(true);
					addButton.setEnabled(true);
					xData = new double[num][dim];
					yData = new double[num];
					i_n = 0;
				}
			});

			// Step 3: one click per sample; after the last one the learning-rate field unlocks.
			addButton.addActionListener(new ActionListener() {

				@Override
				public void actionPerformed(ActionEvent e) {
					if (i_n < num) {
						double[] x;
						double y;
						// Parse everything before storing anything, so a bad entry leaves no partial sample.
						try {
							x = PerceptronMath.parseVector(fieldX.getText(), dim);
							y = Double.parseDouble(fieldY.getText());
						} catch (IllegalArgumentException ex) {
							textArea.append("\n输入无效: x需要" + dim + "个以逗号分隔的数值, y需要一个数值");
							return;
						}
						xData[i_n] = x;
						yData[i_n] = y;
						StringBuilder row = new StringBuilder();
						row.append("\n(").append(i_n).append("): y=").append(y).append('\t');
						for (int d = 0; d < dim; d++) {
							row.append('x').append(d).append('=').append(x[d]).append('\t');
						}
						textArea.append(row.toString());
						i_n++;
					}
					if (i_n >= num) {
						i_n = 0;
						fieldX.setEditable(false);
						fieldY.setEditable(false);
						addButton.setEnabled(false);
						field3.setEditable(true);
					}
				}
			});

			add(addButton);
			addButton.setBounds(20, 80, 60, 25);
			add(confirmButton);
			confirmButton.setBounds(100, 80, 60, 25);
			add(label3);
			label3.setBounds(180, 80, 70, 25);
			add(field3);
			field3.setBounds(260, 80, 60, 25);
			add(original);
			original.setBounds(340, 80, 100, 25);
			add(dual);
			dual.setBounds(450, 80, 100, 25);

			// Step 4: learning rate; unlocks both training buttons.
			field3.addActionListener(new ActionListener() {

				@Override
				public void actionPerformed(ActionEvent e) {
					double rate;
					try {
						rate = Double.parseDouble(field3.getText());
					} catch (NumberFormatException ex) {
						rate = Double.NaN;
					}
					// A rate <= 0 never moves the model toward a separating hyperplane, so training
					// would only stop at the iteration cap.
					if (!(rate > 0) || Double.isInfinite(rate)) {
						textArea.append("\n学习率必须是正数");
						return;
					}
					learn = rate;
					textArea.append("\n学习率为:" + learn);
					field3.setEditable(false);
					original.setEnabled(true);
					dual.setEnabled(true);
				}
			});
			ActionListener originalActionListener = new OriginalActionListener();
			original.addActionListener(originalActionListener);
			ActionListener dualActionListener = new DualActionListener();
			dual.addActionListener(dualActionListener);
		}
	}

	/** Bottom panel: clear log, reset everything, and predict a new point. */
	public class ButtonPanel extends JPanel {
		private JButton clearButton;
		private JButton resetButton;

		public ButtonPanel() {
			clearButton = new JButton("清空");
			resetButton = new JButton("重置");
			inputField = new JTextField("输入预测数据");
			predictButton = new JButton("预测");
			add(clearButton);
			add(resetButton);
			add(inputField);
			add(predictButton);
			inputField.setEditable(false);
			predictButton.setEnabled(false);
			ActionListener predictActionListener = new PredictActionListener();
			predictButton.addActionListener(predictActionListener);
			clearButton.addActionListener(new ActionListener() {

				@Override
				public void actionPerformed(ActionEvent e) {
					textArea.setText("");
				}
			});
			// Returns the model and the UI to the "enter dimension" state; the log is kept.
			resetButton.addActionListener(new ActionListener() {

				@Override
				public void actionPerformed(ActionEvent e) {
					PerceptronFrame.this.resetModel();

					field1.setEditable(true);
					field1.setText("");
					field2.setText("");
					field2.setEditable(false);
					fieldX.setText("");
					fieldX.setEditable(false);
					fieldY.setText("");
					fieldY.setEditable(false);
					field3.setText("");
					field3.setEditable(false);
					addButton.setEnabled(false);
					confirmButton.setEnabled(false);
					original.setEnabled(false);
					dual.setEnabled(false);
					inputField.setEditable(false);
					predictButton.setEnabled(false);
				}
			});
		}
	}

	/**
	 * Primal-form training: starting from w = 0, b = 0, repeatedly takes the first misclassified
	 * sample i (y_i * (w*x_i + b) <= 0) and applies w += learn*y_i*x_i, b += learn*y_i.
	 */
	public class OriginalActionListener implements ActionListener {

		@Override
		public void actionPerformed(ActionEvent e) {
			int step = 0;// row number in the iteration table
			textArea.append("\n迭代次数\t误分类点\tw\tb\tw*x+b");
			textArea.append("\n" + (step++) + "\t初始状态\t0\t0\t0");
			w = new double[dim];
			b = 0;
			int miss;
			while ((miss = firstMisclassifiedPrimal()) >= 0) {
				if (step > MAX_ITERATIONS) {
					reportNotConverged();
					break;
				}
				b = b + learn * yData[miss];
				for (int d = 0; d < dim; d++) {
					w[d] = w[d] + learn * xData[miss][d] * yData[miss];
				}
				textArea.append(formatRow(step++, miss));
			}
			inputField.setEditable(true);
			predictButton.setEnabled(true);
		}

		/** Formats one table row: step, misclassified sample, w, b and the hyperplane w*x+b. */
		private String formatRow(int step, int miss) {
			StringBuilder row = new StringBuilder();
			row.append('\n').append(step).append("\t(").append(miss).append(")\t");
			row.append(PerceptronMath.formatVector(w));
			row.append('\t').append(b).append('\t');
			for (int d = 0; d < dim; d++) {
				row.append(w[d]).append('x').append(d);
				// A negative next coefficient already prints its own '-' sign.
				if (d + 1 < dim && w[d + 1] >= 0) {
					row.append('+');
				}
			}
			if (b >= 0) {
				row.append('+');
			}
			row.append(b);
			return row.toString();
		}
	}

	/**
	 * Dual-form training: keeps alpha_i per sample and updates alpha_i += learn, b += learn*y_i for
	 * the first misclassified sample, using the precomputed Gram matrix. On exit w is recovered as
	 * sum_i alpha_i*y_i*x_i so that prediction uses this model.
	 */
	public class DualActionListener implements ActionListener {

		@Override
		public void actionPerformed(ActionEvent e) {
			gram = PerceptronMath.gramMatrix(xData, num, dim);
			alpha = new double[num];
			b = 0;
			int step = 0;// row number in the iteration table

			// Header and initial row have one column per alpha plus one for b.
			StringBuilder header = new StringBuilder("\n迭代次数\t误分类点\t");
			StringBuilder initial = new StringBuilder();
			initial.append('\n').append(step++).append("\t初始状态\t");
			for (int k = 0; k < num; k++) {
				header.append("alpha").append(k).append('\t');
				initial.append("0\t");
			}
			header.append('b');
			initial.append('0');
			textArea.append(header.toString());
			textArea.append(initial.toString());

			int miss;
			while ((miss = firstMisclassifiedDual()) >= 0) {
				if (step > MAX_ITERATIONS) {
					reportNotConverged();
					break;
				}
				alpha[miss] = alpha[miss] + learn;
				b = b + yData[miss] * learn;
				StringBuilder row = new StringBuilder();
				row.append('\n').append(step++).append("\tx").append(miss).append('\t');
				for (int k = 0; k < num; k++) {
					row.append(alpha[k]).append('\t');
				}
				row.append(b);
				textArea.append(row.toString());
			}
			// Without this, prediction would dereference a null w (or reuse a stale primal w with
			// this run's b).
			w = PerceptronMath.weightsFromAlpha(alpha, yData, xData, num, dim);
			textArea.append("\nw=" + PerceptronMath.formatVector(w) + "\tb=" + b);
			inputField.setEditable(true);
			predictButton.setEnabled(true);
		}
	}

	/** Classifies the comma-separated point in the prediction field by the sign of w*x+b. */
	public class PredictActionListener implements ActionListener {

		@Override
		public void actionPerformed(ActionEvent e) {
			if (w == null) {
				textArea.append("\n请先训练模型");
				return;
			}
			double[] x;
			try {
				x = PerceptronMath.parseVector(inputField.getText(), dim);
			} catch (IllegalArgumentException ex) {
				textArea.append("\n输入无效: 需要" + dim + "个以逗号分隔的数值");
				return;
			}
			if (multiply(w, x, b) <= 0) {
				textArea.append("\n输入的x是反例");
			} else {
				textArea.append("\n输入的x是正例");
			}
		}

	}

	/** Returns the first sample index with y_i*(w*x_i+b) <= 0 under the current model, or -1. */
	private int firstMisclassifiedPrimal() {
		for (int k = 0; k < num; k++) {
			if (yData[k] * multiply(w, xData[k], b) <= 0) {
				return k;
			}
		}
		return -1;
	}

	/** Returns the first sample index misclassified under the current dual model, or -1. */
	private int firstMisclassifiedDual() {
		for (int k = 0; k < num; k++) {
			if (yData[k] * multiply2(alpha, gram, b, k) <= 0) {
				return k;
			}
		}
		return -1;
	}

	/** Logs that training stopped at the iteration cap without separating the data. */
	private void reportNotConverged() {
		textArea.append("\n已达到最大迭代次数" + MAX_ITERATIONS + ",训练未收敛(训练数据可能线性不可分)");
	}

	/**
	 * Evaluates the primal decision function over the first {@code dim} components.
	 *
	 * @param w2 weight vector (at least {@code dim} entries)
	 * @param ds input vector (at least {@code dim} entries)
	 * @param b2 bias
	 * @return w2 . ds + b2
	 */
	public double multiply(double[] w2, double[] ds, double b2) {
		return PerceptronMath.linear(w2, ds, b2, dim);
	}

	/**
	 * Evaluates the dual decision function for training sample {@code i_n2}:
	 * sum_i alpha2[i] * yData[i] * gram2[i_n2][i] + b2.
	 *
	 * @param alpha2 dual coefficients (at least {@code num} entries)
	 * @param gram2 Gram matrix of the training inputs
	 * @param b2 bias
	 * @param i_n2 index of the training sample being evaluated
	 * @return the decision value for that sample
	 */
	public double multiply2(double[] alpha2, double[][] gram2, double b2,
			int i_n2) {
		double sum = 0;
		for (int k = 0; k < num; k++) {
			sum = sum + alpha2[k] * yData[k] * gram2[i_n2][k];
		}
		return sum + b2;
	}

	/** Pure perceptron arithmetic and input parsing, kept free of Swing state. */
	static final class PerceptronMath {

		private PerceptronMath() {
		}

		/** Returns w . x + b using the first {@code dim} components of w and x. */
		static double linear(double[] w, double[] x, double b, int dim) {
			double sum = 0;
			for (int d = 0; d < dim; d++) {
				sum = sum + w[d] * x[d];
			}
			return sum + b;
		}

		/** Returns the num x num matrix whose [i][j] entry is x_i . x_j over the first dim features. */
		static double[][] gramMatrix(double[][] x, int num, int dim) {
			double[][] g = new double[num][num];
			for (int i = 0; i < num; i++) {
				for (int j = 0; j < num; j++) {
					double sum = 0;
					for (int d = 0; d < dim; d++) {
						sum = sum + x[i][d] * x[j][d];
					}
					g[i][j] = sum;
				}
			}
			return g;
		}

		/** Recovers primal weights from dual coefficients: w = sum_i alpha_i * y_i * x_i. */
		static double[] weightsFromAlpha(double[] alpha, double[] y, double[][] x, int num, int dim) {
			double[] w = new double[dim];
			for (int i = 0; i < num; i++) {
				double coef = alpha[i] * y[i];
				for (int d = 0; d < dim; d++) {
					w[d] = w[d] + coef * x[i][d];
				}
			}
			return w;
		}

		/**
		 * Parses the first {@code dim} comma-separated numbers of {@code text}. Both ASCII and
		 * full-width commas are accepted; extra trailing components are ignored.
		 *
		 * @throws IllegalArgumentException if fewer than {@code dim} components are present or one
		 *     of them is not a number ({@link NumberFormatException})
		 */
		static double[] parseVector(String text, int dim) {
			String[] parts = text.split("[,\uFF0C]");
			if (parts.length < dim) {
				throw new IllegalArgumentException(
						"expected " + dim + " comma-separated values but got " + parts.length);
			}
			double[] v = new double[dim];
			for (int d = 0; d < dim; d++) {
				v[d] = Double.parseDouble(parts[d].trim());
			}
			return v;
		}

		/** Returns the trimmed text as a positive int, or -1 if it is not a positive integer. */
		static int parsePositiveInt(String text) {
			try {
				int value = Integer.parseInt(text.trim());
				return value > 0 ? value : -1;
			} catch (NumberFormatException ex) {
				return -1;
			}
		}

		/** Formats a vector as "(v0,v1,...)". */
		static String formatVector(double[] v) {
			StringBuilder sb = new StringBuilder("(");
			for (int d = 0; d < v.length; d++) {
				if (d > 0) {
					sb.append(',');
				}
				sb.append(v[d]);
			}
			return sb.append(')').toString();
		}
	}
}

package perceptron;

import java.awt.*;

import javax.swing.*;

public class PerceptronTest {

	/**
	 * Application entry point: schedules creation of the perceptron window on the AWT event
	 * dispatch thread.
	 *
	 * @param args unused
	 */
	public static void main(String[] args) {
		Runnable showWindow = new Runnable() {

			@Override
			public void run() {
				launchFrame();
			}
		};
		EventQueue.invokeLater(showWindow);
	}

	/** Creates the main window, makes closing it exit the JVM, and shows it. */
	private static void launchFrame() {
		JFrame window = new PerceptronFrame();
		window.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
		window.setVisible(true);
	}

}



评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值