1.概念
在统计学中,线性回归(Linear Regression)是利用被称为线性回归方程的最小平方函数对一个或多个自变量和因变量之间关系进行建模的一种回归分析。这种函数是一个或多个被称为回归系数的模型参数的线性组合。
在回归分析中,只包括一个自变量和一个因变量,且二者的关系可以用一条直线近似表示,这种回归分析被称为一元线性回归分析。如果回归分析包括两个或者两个以上的自变量,且因变量和自变量之间是线性关系,则被称为多元线性回归分析。
假设有一个房屋销售的数据如下:
面积(m^2) 房间数 销售价钱(万元)
123 3 250
150 3 320
87 2 160
102 2 220 ......
假设我们又很多这样的数据,这些就是训练数据,我们希望学习一个模型当新来一个面积数据的时候,我们可以自动预测出其房屋的销售价格。现在,我们用x1代表其房屋面积,x2代表房间数,那么我们可以构建一个二元线性方程估计函数:
2.函数模型
其实无论是一元线性方程还是多元线性方程,可以统一写成如下的格式:
如果想要通过上面建立的模型去预测,那么必须求出θ的值,于是求线性方程则演变成了求解方程的参数θ的值。
当然我们也需要一个机制去评估我们的θ是否是最优的,所以我们需要对我们的h函数进行评估。用来评估的函数一般称为损失函数(loss function),用来描述h函数不好的程度。损失函数如下:
损失函数是x(i)的估计只与真实值y(i)之差的平方和,其中系数1/2是为了求导的时候使得系数为1。如何调整θ以使得J(θ)取得最小值有很多方法,其中有最小二乘法(least square)和梯度下降法等。但是最小二乘法在求解θ时存在的局限性(X要求满秩而且还要求矩阵的逆耗费时间),一般常用的是梯度下降法来求解。
3.梯度下降求解
梯度下降算法有两种,一个是批量梯度下降算法,另一个是随机梯度下降算法
1.批量梯度下降法的流程:
(1)将J(θ)函数对θ进行求偏导,得到每个θ对应的梯度
(2)为了最小化损失函数,所以按照每个参数θ的梯度负方向,来更新每个θ,直到收敛
(3)可以看出来,上面求解的是一个全局的最优解,但是每迭代一次都需要用到训练集的所有数据,如果m很大,那么程序的时间复杂度就很高。
注意: 上面参数alpha是学习速率,决定下降的步伐,如果太小则函数收敛的速度就会很慢,如果太大则会出现有可能越过最小值;初始点不同,获得最小值也不同,因此梯度下降求得的知识局部最小值;越接近最小值,下降速度越慢;计算批量梯度下降算法的时候,计算每一个θ都需要遍历就算所有的样本。
批量梯度下降的步骤可以归纳为:
(1)先确定向下一步的步伐大小,我们称为Learning rate
(2)任意给定一个初始值:θ向量,一般为0向量
(3)确定一个向下的方向,并向下走预先规定的步伐,并更新θ向量
(4)当下降的高度小于某个定义的值,则停止下降
2.随机梯度下降法
刚刚在上面批量梯度下降算法中提到该算法的不足,因为每次计算梯度都需要遍历所有的样本点。这是因为梯度是J(θ)的导数,而J(θ)是需要考虑所有样本的误差和 ,这个方法问题就是,扩展性问题,当样本点很大的时候,基本就没法算了。所以接下来又提出了随机梯度下降算法(stochastic gradient descent )。随机梯度下降算法,每次迭代只是考虑让该样本点的J(θ)趋向最小,而不管其他的样本点,这样算法会很快,但是收敛的过程会比较曲折,整体效果上,大多数时候它只能接近局部最优解,而无法真正达到局部最优解。所以适合用于较大训练集的场景。
算法如下:
即每读一条样本,就迭代对θ更新,判断是否收敛,如果没有收敛,则继续读取样本进行处理。
不过,相比较批量梯度下降算法而言,随机梯度下降算法使得J(θ)趋近与最小值的速度更快,但是有可能会在最小值的周围震荡,造成永不收敛。但是在实践中,大部分值都能够接近与最小值,效果也不错。
注意:对于是否收敛的判断方法如下:参数θ的变化距离为0,或者说变化距离小于某一阈值。为减少计算复杂度,该方法更为推荐使用。
Java利用批量梯度下降算法实现了线性回归的过程:
package xudong.Regression;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
public class LinearRegression {
/**
* 线性回归的实现
* @author xudong
* @since 2017-8-7
*/
private double [][] trainData;//数据集矩阵
private int row;
private int column;
private double [] theta;//参数
private double alpha;//学习速率
private int iteration;//迭代次数
public LinearRegression(String filename){
int rowFile=getRowNumber(filename);//获取输入训练数据的行数
int columnFile=getColumnNumber(filename);//获取训练数据的列数
trainData=new double[rowFile][columnFile+1];//加了一个特征x0 x0==1
this.row=rowFile;
this.column=columnFile;
this.alpha=0.001;
this.iteration=1000;
this.theta=new double[column-1];
initialize_theta();
loadTrainDataFromFile(filename,rowFile,columnFile);
}
public LinearRegression(String filename,double alpha,int iteration){
int rowFile=getRowNumber(filename);//获取输入训练数据的行数
int columnFile=getColumnNumber(filename);//获取训练数据的列数
trainData=new double[rowFile][columnFile+1];//加了一个特征x0 x0==1
this.row=rowFile;
this.column=columnFile;
this.alpha=alpha;
this.iteration=iteration;
this.theta=new double[column-1];
initialize_theta();
loadTrainDataFromFile(filename,rowFile,columnFile);
}
/**
* 从文件中加载数据集到trainData中
* @param filename
* @param rowFile
* @param columnFile
*/
private void loadTrainDataFromFile(String filename, int rowFile,int columnFile) {
for(int i=0;i<row;i++){//trainData第一例全是0
trainData[i][0]=1.0;
}
File file=new File(filename);
BufferedReader br=null;
try {
br=new BufferedReader(new FileReader(file));
String temp=null;
int counter=0;
while((counter < row)&&(temp=br.readLine())!=null){
String[] tempData=temp.split(" ");
for(int i=0;i<column;i++){
trainData[counter][i+1]=Double.parseDouble(tempData[i]);
}
counter++;
}
br.close();
} catch (FileNotFoundException e) {
e.printStackTrace();
} catch (IOException e) {
e.printStackTrace();
}
}
/**
* 初始化参数theta
*/
private void initialize_theta() {
for(int i=0;i<theta.length;i++){
theta[i]=1.0;
}
}
/**
* 获取数据集的行数
* @param filename
* @return
*/
private int getRowNumber(String filename){
int count=0;
File file=new File(filename);
BufferedReader br=null;
try {
br=new BufferedReader(new FileReader(file));
while(br.readLine()!=null){
count++;
}
br.close();
} catch (FileNotFoundException e) {
e.printStackTrace();
} catch (IOException e) {
e.printStackTrace();
} finally {
if(br!=null){
try {
br.close();
} catch (IOException e) {
e.printStackTrace();
}
}
}
return count;
}
/**
* 获取数据集的列数(特征维度)
* @param filename
* @return
*/
private int getColumnNumber(String filename){
int count=0;
File file=new File(filename);
BufferedReader br=null;
try {
br=new BufferedReader(new FileReader(file));
String temp=br.readLine();
if(temp!=null){
count=temp.split(" ").length;
}
br.close();
} catch (FileNotFoundException e) {
e.printStackTrace();
} catch (IOException e) {
e.printStackTrace();
} finally {
if(br!=null){
try {
br.close();
} catch (IOException e) {
e.printStackTrace();
}
}
}
return count;
}
/**
* 训练模型并计算theta
*/
public void trainTheta(){
int iteration=this.iteration;
while(iteration-- > 0){
//对每一个thetai 求偏导
double[] partial_derivative=compute_partial_derivative();
for(int i=0;i<theta.length;i++){
theta[i]-=alpha*partial_derivative[i];
}
}
}
private double[] compute_partial_derivative() {
double[] partial_derivative=new double[theta.length];
for(int j=0;j<theta.length;j++){
partial_derivative[j]= compute_partial_derivative_for_theta(j);
}
return partial_derivative;
}
private double compute_partial_derivative_for_theta(int j) {
double sum=0.0;
for(int i=0;i<row;i++){
sum+=h_theta_x_i_minus_y_i_times_x_j_i(i,j);
}
return sum/row;
}
private double h_theta_x_i_minus_y_i_times_x_j_i(int i, int j) {
double[] oneRow=getRow(i);//取一行数据,前面是feature最后一个是y
double result=0.0;
for(int k=0;k<(oneRow.length-1);k++){
result+=theta[k]*oneRow[k];
}
result-=oneRow[oneRow.length-1];
result*=oneRow[j];
return result;
}
private double[] getRow(int i) {
return trainData[i];
}
//这里是主方法
public static void main(String[] args) {
String filename="";
LinearRegression lr=new LinearRegression(filename);
lr.trainTheta();
}
}