本项目共需要两个 Java 文件：LR.java 和 Test.java。
运行结果（由于数据集是随机生成的，每次运行的结果也会不同）：
LR.java:
/**
 * Univariate linear regression {@code yhat = k * x + b}, trained by stochastic gradient
 * descent on half the mean squared error.
 *
 * <p>Not thread-safe: {@link #fit} mutates the model parameters in place.
 */
public class LR
{
    /** Number of passes over the data used by {@link #fit(double[], double[])}. */
    private static final int DEFAULT_EPOCHS = 10000;

    /** Initial slope before any training. */
    private static final double INITIAL_K = 5;

    /** Initial intercept before any training. */
    private static final double INITIAL_B = 10;

    /** Current slope. */
    private double k;

    /** Current intercept. */
    private double b;

    /** Copy of the inputs from the most recent fit; null until fit is called. */
    private double[] trainX;

    /** Copy of the targets from the most recent fit; null until fit is called. */
    private double[] trainY;

    /** Step size applied to every SGD update (learning rate). */
    private double learningRate = 0.01;

    /** Creates a model with slope 5 and intercept 10. */
    public LR()
    {
        k = INITIAL_K;
        b = INITIAL_B;
    }

    /**
     * Trains the model for 10000 epochs of per-sample SGD.
     *
     * <p>Training continues from the current parameters; it does not reset them.
     *
     * @param X training inputs
     * @param y training targets, same length as {@code X}
     * @throws NullPointerException if either array is null
     * @throws IllegalArgumentException if the arrays differ in length
     */
    public void fit(double[] X, double[] y)
    {
        fit(X, y, DEFAULT_EPOCHS);
    }

    /**
     * Trains the model for the given number of epochs of per-sample SGD.
     *
     * <p>Training continues from the current parameters; it does not reset them.
     *
     * @param X training inputs
     * @param y training targets, same length as {@code X}
     * @param epochs number of full passes over the data; 0 only records the data
     * @throws NullPointerException if either array is null
     * @throws IllegalArgumentException if the arrays differ in length or {@code epochs < 0}
     */
    public void fit(double[] X, double[] y, int epochs)
    {
        java.util.Objects.requireNonNull(X, "X");
        java.util.Objects.requireNonNull(y, "y");
        if (X.length != y.length)
        {
            throw new IllegalArgumentException(
                    "X and y must have the same length: " + X.length + " != " + y.length);
        }
        if (epochs < 0)
        {
            throw new IllegalArgumentException("epochs must be non-negative: " + epochs);
        }
        // Copy so later changes to the caller's arrays cannot alter loss().
        trainX = X.clone();
        trainY = y.clone();
        for (int epoch = 0; epoch < epochs; epoch++)
        {
            for (int j = 0; j < trainX.length; j++)
            {
                // Compute the residual once so both parameters step along the gradient
                // evaluated at the same (k, b); recomputing after updating b would mix
                // old and new parameter values.
                double err = k * trainX[j] + b - trainY[j];
                b -= learningRate * err;
                k -= learningRate * err * trainX[j];
            }
        }
    }

    /**
     * Predicts targets for the given inputs with the current slope and intercept.
     *
     * @param x inputs to evaluate
     * @return a new array holding {@code k * x[i] + b} for each input
     */
    public double[] predict(double[] x)
    {
        double[] yhat = new double[x.length];
        for (int i = 0; i < x.length; i++)
        {
            yhat[i] = k * x[i] + b;
        }
        return yhat;
    }

    /**
     * Returns half the mean squared error of the current model on the most recent
     * training data.
     *
     * @return {@code sum((k*x + b - y)^2) / (2n)}; NaN if the training set was empty
     * @throws IllegalStateException if {@link #fit} has not been called
     */
    public double loss()
    {
        if (trainX == null)
        {
            throw new IllegalStateException("fit must be called before computing the loss");
        }
        double sum = 0;
        for (int i = 0; i < trainX.length; i++)
        {
            double err = k * trainX[i] + b - trainY[i];
            sum += err * err;
        }
        return sum / (2 * trainX.length);
    }

    /**
     * Same as {@link #loss()}; kept for existing callers.
     *
     * @return half the mean squared error on the most recent training data
     * @throws IllegalStateException if {@link #fit} has not been called
     */
    public double Lost()
    {
        return loss();
    }

    /** @return the current slope */
    public double getK()
    {
        return k;
    }

    /** @return the current intercept */
    public double getB()
    {
        return b;
    }

    /** @return the SGD step size */
    public double getLearningRate()
    {
        return learningRate;
    }

    /**
     * Sets the SGD step size.
     *
     * @param value new learning rate; must be positive and finite
     * @throws IllegalArgumentException if {@code value} is not positive and finite
     */
    public void setLearningRate(double value)
    {
        if (!(value > 0.0 && Double.isFinite(value)))
        {
            throw new IllegalArgumentException("learning rate must be positive and finite: " + value);
        }
        learningRate = value;
    }
}
Test.java:
import java.util.Random;
public class Test {
    /**
     * Demo entry point: builds a noisy linear dataset, fits an {@link LR} model to it and
     * prints the learned slope and intercept.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        final int sampleCount = 20;
        double[] features = new double[sampleCount];
        double[] targets = new double[sampleCount];
        Random noise = new Random();

        // Points on the line y = 2x + 1, each shifted up by uniform noise in [0, 1).
        for (int idx = 0; idx < sampleCount; idx++) {
            features[idx] = idx;
            targets[idx] = 2 * features[idx] + 1 + noise.nextDouble();
        }

        LR model = new LR();
        // Use a smaller step than the default; presumably 0.01 is too large for x up to 19.
        model.setLearningRate(0.001);
        model.fit(features, targets);

        double slope = model.getK();
        double intercept = model.getB();
        System.out.println("斜率:" + slope + "\n" + "截距:" + intercept);
    }
}