import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
public class LinearRegression {
/*
* 训练数据示例:
* x0 x1 x2 y
1.0 1.0 2.0 7.2
1.0 2.0 1.0 4.9
1.0 3.0 0.0 2.6
1.0 4.0 1.0 6.3
1.0 5.0 -1.0 1.0
1.0 6.0 0.0 4.7
1.0 7.0 -2.0 -0.6
注意!!!!x1,x2,y三列是用户实际输入的数据,x0是为了推导出来的公式统一,特地补上的一列。
x0,x1,x2是“特征”,y是结果
h(x) = theta0 * x0 + theta1* x1 + theta2 * x2
theta0,theta1,theta2 是想要训练出来的参数
此程序采用“梯度下降法”
*
*/
private double [][] trainData;//训练数据,一行一个数据,每一行最后一个数据为 y
private int row;//训练数据 行数
private int column;//训练数据 列数
private double [] theta;//参数theta
private double alpha;//训练步长
private int iteration;//迭代次数
public LinearRegression(String fileName)
{
int rowoffile=getRowNumber(fileName);//获取输入训练数据文本的 行数
int columnoffile = getColumnNumber(fileName);//获取输入训练数据文本的 列数
trainData = new double[rowoffile][columnoffile+1];//这里需要注意,为什么要+1,因为为了使得公式整齐,我们加了一个特征x0,x0恒等于1