本文主要记录本人在学习机器学习过程中的相关代码实现，参考《机器学习实战》。
package perceptron;
import java.awt.BorderLayout;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import javax.swing.*;
/**
 * Swing front end for training a binary perceptron on data typed in by the user and then
 * classifying new points.
 *
 * <p>Widgets are enabled one step at a time:
 * <ol>
 *   <li>feature dimension;</li>
 *   <li>sample count;</li>
 *   <li>each sample (x as comma-separated numbers, y as 1 or -1);</li>
 *   <li>learning rate;</li>
 *   <li>one training form: the primal form (learns w and b directly) or the dual form
 *       (learns one coefficient alpha per sample via the Gram matrix);</li>
 *   <li>prediction.</li>
 * </ol>
 * Every training step is appended to the text area as one row of a trace table.
 *
 * <p>All work runs on the Swing event-dispatch thread. The perceptron never converges on
 * data that is not linearly separable, so training stops after {@code MAX_UPDATES} weight
 * updates instead of freezing the window.
 */
public class PerceptronFrame extends JFrame {
    protected int dim = 1;// feature dimension of the input space
    protected int num = 1;// number of training samples
    double[][] xData;// training inputs, xData[sample][feature]
    double[] yData;// training labels, 1 or -1
    int i_n = 0;// index of the next sample to be entered with the "添加" button
    int i_d = 0;// formerly a shared loop index over dim; no longer used, kept for compatibility
    protected double learn = 0;// learning rate, must be > 0
    double[] w;// weight vector: learned by the primal form, recovered from alpha by the dual form
    double[] alpha;// dual-form coefficients, one per sample
    double b;// bias
    double[][] gram;// Gram matrix, gram[i][j] = x_i . x_j

    /** Maximum number of weight updates in one training run before giving up. */
    private static final int MAX_UPDATES = 10000;
    private static final int DEFAULT_WIDTH = 600;
    private static final int DEFAULT_HEIGHT = 400;
    private JPanel panel1;
    private InputPanel panel2;
    private ButtonPanel panel3;
    private JTextArea textArea;
    private JTextField field2;
    public JTextField fieldX;
    public JTextField fieldY;
    public JButton addButton;
    public JTextField field3;
    public JButton original;
    public JButton dual;
    public JButton confirmButton;
    public JTextField field1;
    public JTextField inputField;
    public JButton predictButton;

    /** Builds the window: output text area on top, data entry in the middle, actions at the bottom. */
    public PerceptronFrame() {
        resetModel();
        setTitle("感知机学习");
        setSize(DEFAULT_WIDTH, DEFAULT_HEIGHT);
        panel1 = new JPanel();
        textArea = new JTextArea("请点击下面的按钮进行操作", 10, 52);
        textArea.setEditable(false);
        textArea.setLineWrap(true);
        JScrollPane scrollPane = new JScrollPane(textArea);
        panel1.add(scrollPane);
        add(panel1, BorderLayout.NORTH);
        panel2 = new InputPanel();
        add(panel2, BorderLayout.CENTER);
        panel3 = new ButtonPanel();
        add(panel3, BorderLayout.SOUTH);
    }

    /** Restores all model and data-entry state to its initial values (widgets are not touched). */
    private void resetModel() {
        dim = 1;
        num = 1;
        xData = null;
        yData = null;
        i_n = 0;
        i_d = 0;
        learn = 0;
        w = null;
        alpha = null;
        b = 0;
        gram = null;
    }

    /** Appends the message shown when training hits {@code MAX_UPDATES} without converging. */
    private void reportNotConverged() {
        textArea.append("\n已达到最大迭代次数" + MAX_UPDATES + ",训练数据可能线性不可分,已停止训练");
    }

    /** Returns the first sample index with y*(w.x+b) <= 0 under the primal model, or -1 if none. */
    private int firstMisclassifiedPrimal() {
        for (int n = 0; n < num; n++) {
            if (yData[n] * multiply(w, xData[n], b) <= 0) {
                return n;
            }
        }
        return -1;
    }

    /** Returns the first sample index misclassified under the dual model, or -1 if none. */
    private int firstMisclassifiedDual() {
        for (int n = 0; n < num; n++) {
            if (yData[n] * multiply2(alpha, gram, b, n) <= 0) {
                return n;
            }
        }
        return -1;
    }

    /**
     * Panel for entering the dimension, the sample count, each sample and the learning rate.
     * It also hosts the two training buttons.
     */
    public class InputPanel extends JPanel {
        public InputPanel() {
            setSize(550, 100);
            JLabel labelx = new JLabel("x:");
            fieldX = new JTextField(10);
            JLabel labely = new JLabel("y:");
            fieldY = new JTextField(10);
            JLabel label1 = new JLabel("输入空间特征维数:");
            field1 = new JTextField(10);
            JLabel label2 = new JLabel("输入训练数据组数:");
            field2 = new JTextField(10);
            field2.setEditable(false);
            setLayout(null);
            add(labelx);
            labelx.setBounds(10, 10, 30, 20);
            add(fieldX);
            fieldX.setBounds(30, 10, 100, 20);
            fieldX.setEditable(false);
            add(labely);
            labely.setBounds(10, 40, 30, 20);
            add(fieldY);
            fieldY.setBounds(30, 40, 100, 20);
            fieldY.setEditable(false);
            add(label1);
            label1.setBounds(200, 10, 150, 20);
            add(field1);
            field1.setBounds(320, 10, 100, 20);
            add(label2);
            label2.setBounds(200, 40, 150, 20);
            add(field2);
            field2.setBounds(320, 40, 100, 20);
            addButton = new JButton("添加");
            addButton.setEnabled(false);
            // NOTE(review): no listener is ever attached to confirmButton; it is only disabled/reset.
            confirmButton = new JButton("确定");
            confirmButton.setEnabled(false);
            JLabel label3 = new JLabel("输入学习率:");
            field3 = new JTextField();
            field3.setEditable(false);
            original = new JButton("原始形式");
            original.setEnabled(false);
            dual = new JButton("对偶形式");
            dual.setEnabled(false);
            // Step 1: feature dimension (positive integer).
            field1.addActionListener(new ActionListener() {
                @Override
                public void actionPerformed(ActionEvent e) {
                    int value = PerceptronSupport.parsePositiveInt(field1.getText());
                    if (value < 0) {
                        textArea.append("\n空间特征维数必须为正整数");
                        return;
                    }
                    dim = value;
                    textArea.append("\n空间特征维数为:" + dim);
                    field1.setEditable(false);
                    field2.setEditable(true);
                }
            });
            // Step 2: number of samples (positive integer); allocates the data arrays.
            field2.addActionListener(new ActionListener() {
                @Override
                public void actionPerformed(ActionEvent e) {
                    int value = PerceptronSupport.parsePositiveInt(field2.getText());
                    if (value < 0) {
                        textArea.append("\n训练数据个数必须为正整数");
                        return;
                    }
                    num = value;
                    textArea.append("\n训练数据个数为:" + num);
                    field2.setEditable(false);
                    fieldX.setEditable(true);
                    fieldY.setEditable(true);
                    addButton.setEnabled(true);
                    xData = new double[num][dim];
                    yData = new double[num];
                }
            });
            // Step 3: one sample per click. Input is validated before anything is stored,
            // so a bad entry leaves no half-written sample behind.
            addButton.addActionListener(new ActionListener() {
                @Override
                public void actionPerformed(ActionEvent e) {
                    if (i_n < num) {
                        double label;
                        double[] features;
                        try {
                            label = Double.parseDouble(fieldY.getText());
                            features = PerceptronSupport.parseVector(fieldX.getText(), dim);
                        } catch (IllegalArgumentException ex) {
                            textArea.append("\n输入格式错误,x应为" + dim + "个以逗号分隔的数值,y应为1或-1");
                            return;
                        }
                        if (label != 1.0 && label != -1.0) {
                            textArea.append("\ny必须为1或-1");
                            return;
                        }
                        yData[i_n] = label;
                        xData[i_n] = features;
                        textArea.append("\n(" + i_n + "): y=" + yData[i_n] + " ");
                        for (int d = 0; d < dim; d++) {
                            textArea.append("x" + d + "=" + xData[i_n][d] + " ");
                        }
                        i_n++;
                    }
                    if (i_n >= num) {
                        i_n = 0;
                        fieldX.setEditable(false);
                        fieldY.setEditable(false);
                        addButton.setEnabled(false);
                        field3.setEditable(true);
                    }
                }
            });
            add(addButton);
            addButton.setBounds(20, 80, 60, 25);
            add(confirmButton);
            confirmButton.setBounds(100, 80, 60, 25);
            add(label3);
            label3.setBounds(180, 80, 70, 25);
            add(field3);
            field3.setBounds(260, 80, 60, 25);
            add(original);
            original.setBounds(340, 80, 100, 25);
            add(dual);
            dual.setBounds(450, 80, 100, 25);
            // Step 4: learning rate. With a zero rate w and b never change, so every sample
            // would stay misclassified; negative or infinite rates are rejected too.
            field3.addActionListener(new ActionListener() {
                @Override
                public void actionPerformed(ActionEvent e) {
                    double rate;
                    try {
                        rate = Double.parseDouble(field3.getText());
                    } catch (NumberFormatException ex) {
                        rate = Double.NaN;
                    }
                    if (!(rate > 0) || Double.isInfinite(rate)) {
                        textArea.append("\n学习率必须为正数");
                        return;
                    }
                    learn = rate;
                    textArea.append("\n学习率为:" + learn);
                    field3.setEditable(false);
                    original.setEnabled(true);
                    dual.setEnabled(true);
                }
            });
            ActionListener originalActionListener = new OriginalActionListener();
            original.addActionListener(originalActionListener);
            ActionListener dualActionListener = new DualActionListener();
            dual.addActionListener(dualActionListener);
        }
    }

    /** Bottom panel with the clear, reset, prediction-input and predict controls. */
    public class ButtonPanel extends JPanel {
        private JButton clearButton;
        private JButton resetButton;

        public ButtonPanel() {
            clearButton = new JButton("清空");
            resetButton = new JButton("重置");
            inputField = new JTextField("输入预测数据");
            predictButton = new JButton("预测");
            add(clearButton);
            add(resetButton);
            add(inputField);
            add(predictButton);
            inputField.setEditable(false);
            predictButton.setEnabled(false);
            ActionListener predictActionListener = new PredictActionListener();
            predictButton.addActionListener(predictActionListener);
            // Clears only the output text; model state is kept.
            clearButton.addActionListener(new ActionListener() {
                @Override
                public void actionPerformed(ActionEvent e) {
                    textArea.setText("");
                }
            });
            // Discards all data and the trained model, and returns the widgets to step 1.
            resetButton.addActionListener(new ActionListener() {
                @Override
                public void actionPerformed(ActionEvent e) {
                    resetModel();
                    field1.setEditable(true);
                    field1.setText("");
                    field2.setText("");
                    field2.setEditable(false);
                    fieldX.setText("");
                    fieldX.setEditable(false);
                    fieldY.setText("");
                    fieldY.setEditable(false);
                    field3.setText("");
                    field3.setEditable(false);
                    addButton.setEnabled(false);
                    confirmButton.setEnabled(false);
                    original.setEnabled(false);
                    dual.setEnabled(false);
                    inputField.setEditable(false);
                    predictButton.setEnabled(false);
                }
            });
        }
    }

    /**
     * Trains the primal form. It repeatedly picks the first misclassified sample n and sets
     * w += learn*y_n*x_n and b += learn*y_n, printing one trace row per update. It stops when
     * no sample is misclassified or after {@code MAX_UPDATES} updates.
     */
    public class OriginalActionListener implements ActionListener {
        @Override
        public void actionPerformed(ActionEvent e) {
            int iteration = 0;// trace row counter; row 0 is the initial state
            textArea.append("\n迭代次数\t误分类点\tw\tb\tw*x+b");
            textArea.append("\n" + (iteration++) + "\t初始状态\t0\t0\t0");
            w = new double[dim];
            b = 0;
            while (true) {
                int n = firstMisclassifiedPrimal();
                if (n < 0) {
                    break;// every sample is on the correct side: converged
                }
                if (iteration > MAX_UPDATES) {
                    reportNotConverged();
                    break;
                }
                b = b + learn * yData[n];
                for (int d = 0; d < dim; d++) {
                    w[d] = w[d] + learn * xData[n][d] * yData[n];
                }
                textArea.append("\n" + (iteration++) + " (" + n + ")" + "\t"
                        + PerceptronSupport.formatVector(w));
                textArea.append("\t" + b + "\t" + PerceptronSupport.formatLinear(w, b));
            }
            inputField.setEditable(true);
            predictButton.setEnabled(true);
        }
    }

    /**
     * Trains the dual form. It precomputes the Gram matrix, then repeatedly picks the first
     * misclassified sample n and sets alpha_n += learn and b += learn*y_n. After training it
     * recovers w = sum(alpha_i*y_i*x_i) so that prediction works for this form too.
     */
    public class DualActionListener implements ActionListener {
        @Override
        public void actionPerformed(ActionEvent e) {
            gram = new double[num][num];
            alpha = new double[num];
            b = 0;
            for (int i = 0; i < num; i++) {
                for (int j = 0; j < num; j++) {
                    gram[i][j] = PerceptronSupport.dot(xData[i], xData[j], dim);
                }
            }
            int iteration = 0;// trace row counter; row 0 is the initial state
            textArea.append("\n迭代次数\t");
            textArea.append("误分类点\t");
            for (int i = 0; i < num; i++) {
                textArea.append("alpha" + i + "\t");
            }
            textArea.append("b");
            // Initial row: one zero per alpha column, then b = 0.
            StringBuilder initialRow = new StringBuilder();
            initialRow.append("\n").append(iteration++).append("\t初始状态\t");
            for (int i = 0; i < num; i++) {
                initialRow.append("0\t");
            }
            initialRow.append("0");
            textArea.append(initialRow.toString());
            while (true) {
                int n = firstMisclassifiedDual();
                if (n < 0) {
                    break;// every sample is on the correct side: converged
                }
                if (iteration > MAX_UPDATES) {
                    reportNotConverged();
                    break;
                }
                alpha[n] = alpha[n] + learn;
                b = b + yData[n] * learn;
                textArea.append("\n" + (iteration++) + "\tx" + n + "\t");
                for (int i = 0; i < num; i++) {
                    textArea.append(alpha[i] + "\t");
                }
                textArea.append(b + "");
            }
            // Prediction evaluates w.x+b, so w must be rebuilt from the dual coefficients.
            w = PerceptronSupport.primalFromDual(alpha, yData, xData, dim);
            textArea.append("\nw=" + PerceptronSupport.formatVector(w) + "\tb=" + b);
            inputField.setEditable(true);
            predictButton.setEnabled(true);
        }
    }

    /**
     * Classifies the comma-separated point in the input field by the sign of w.x+b.
     * A result <= 0 counts as negative.
     */
    public class PredictActionListener implements ActionListener {
        @Override
        public void actionPerformed(ActionEvent e) {
            if (w == null) {
                textArea.append("\n请先训练模型");
                return;
            }
            double[] x;
            try {
                x = PerceptronSupport.parseVector(inputField.getText(), dim);
            } catch (IllegalArgumentException ex) {
                textArea.append("\n输入格式错误,请输入" + dim + "个以逗号分隔的数值");
                return;
            }
            if (multiply(w, x, b) <= 0) {
                textArea.append("\n输入的x是反例");
            } else {
                textArea.append("\n输入的x是正例");
            }
        }
    }

    /**
     * Computes the primal decision value w2.ds + b2 over the first {@code dim} components.
     *
     * @param w2 weight vector
     * @param ds input point
     * @param b2 bias
     * @return the linear score
     */
    public double multiply(double[] w2, double[] ds, double b2) {
        return PerceptronSupport.dot(w2, ds, dim) + b2;
    }

    /**
     * Computes the dual decision value for training sample {@code i_n2}:
     * sum_i(alpha2[i] * yData[i] * gram2[i_n2][i]) + b2.
     *
     * @param alpha2 dual coefficients, one per sample
     * @param gram2 Gram matrix of the training inputs
     * @param b2 bias
     * @param i_n2 index of the training sample being scored
     * @return the linear score of that sample
     */
    public double multiply2(double[] alpha2, double[][] gram2, double b2, int i_n2) {
        double sum = 0;
        for (int i = 0; i < num; i++) {
            sum = sum + alpha2[i] * yData[i] * gram2[i_n2][i];
        }
        return sum + b2;
    }

    /**
     * Stateless numeric, parsing and formatting helpers. They are kept free of Swing state so
     * they can be used and tested without constructing a window.
     */
    static final class PerceptronSupport {
        private PerceptronSupport() {
        }

        /** Returns the dot product of the first {@code dim} components of {@code a} and {@code b}. */
        static double dot(double[] a, double[] b, int dim) {
            double sum = 0;
            for (int d = 0; d < dim; d++) {
                sum = sum + a[d] * b[d];
            }
            return sum;
        }

        /**
         * Recovers the primal weight vector from dual coefficients: w = sum_i(alpha_i*y_i*x_i).
         *
         * @param alpha dual coefficients, one per sample
         * @param y labels, same length as alpha
         * @param x inputs, x[i] has at least {@code dim} components
         * @param dim feature dimension
         * @return a new weight vector of length {@code dim}
         */
        static double[] primalFromDual(double[] alpha, double[] y, double[][] x, int dim) {
            double[] weights = new double[dim];
            for (int i = 0; i < alpha.length; i++) {
                double coefficient = alpha[i] * y[i];
                for (int d = 0; d < dim; d++) {
                    weights[d] += coefficient * x[i][d];
                }
            }
            return weights;
        }

        /**
         * Parses a comma-separated list of numbers and keeps the first {@code dim} values.
         * Extra values are ignored.
         *
         * @throws IllegalArgumentException if fewer than {@code dim} values are present or one
         *     of them is not a number ({@link NumberFormatException} is a subclass)
         */
        static double[] parseVector(String text, int dim) {
            String[] parts = text.split(",");
            if (parts.length < dim) {
                throw new IllegalArgumentException(
                        "expected " + dim + " comma-separated values but got " + parts.length + ": " + text);
            }
            double[] values = new double[dim];
            for (int d = 0; d < dim; d++) {
                values[d] = Double.parseDouble(parts[d]);
            }
            return values;
        }

        /** Parses a positive integer, ignoring surrounding whitespace; returns -1 if it is not one. */
        static int parsePositiveInt(String text) {
            try {
                int value = Integer.parseInt(text.trim());
                return value > 0 ? value : -1;
            } catch (NumberFormatException ex) {
                return -1;
            }
        }

        /** Formats a vector as "(v0,v1,...)". */
        static String formatVector(double[] v) {
            StringBuilder sb = new StringBuilder("(");
            for (int d = 0; d < v.length; d++) {
                if (d > 0) {
                    sb.append(',');
                }
                sb.append(v[d]);
            }
            return sb.append(')').toString();
        }

        /**
         * Formats w.x+b as text, e.g. "1.0x0+1.0x1-3.0". A "+" is inserted only where the
         * next term is non-negative; negative terms carry their own minus sign.
         */
        static String formatLinear(double[] w, double b) {
            StringBuilder sb = new StringBuilder();
            for (int d = 0; d < w.length; d++) {
                sb.append(w[d]).append('x').append(d);
                if (d + 1 < w.length && w[d + 1] >= 0) {
                    sb.append('+');
                }
            }
            if (b >= 0) {
                sb.append('+');
            }
            sb.append(b);
            return sb.toString();
        }
    }
}
package perceptron;
import java.awt.*;
import javax.swing.*;
/** Entry point that opens the perceptron training window. */
public class PerceptronTest {

    /**
     * Schedules window creation on the AWT event-dispatch thread and returns immediately.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        Runnable showWindow = new Runnable() {
            @Override
            public void run() {
                createAndShowFrame();
            }
        };
        EventQueue.invokeLater(showWindow);
    }

    /** Builds the main window, makes closing it exit the JVM, and shows it. */
    private static void createAndShowFrame() {
        JFrame window = new PerceptronFrame();
        window.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        window.setVisible(true);
    }
}