1.概念
第一篇机器学习-线性回归中介绍了线性回归的概念和模型以及java的实现。其实,逻辑回归的本质上是线性回归,只是在特征到结果的映射中加入了一层函数映射,即先把特征线性求和,然后使用一个函数g(z)进行计算。g(z)函数可以将连续值映射到0和1上。也就是说线性回归的输出时y=f(x)=wx,而逻辑回归的输出是y=g(f(x))=g(wx)。
线性回归和逻辑回归的结构图如下所示:
逻辑回归于线性回归的不同点在于:将线性回归的输出范围,例如从负无穷到正无穷,压缩到0和1之间;把大值压缩到这个这个范围还有个很好的用处,就是可以消除特别冒尖的变量的影响。
2.函数模型
刚刚提到线性回归的输出时y=f(x)=wx,而逻辑回归的输出y=g(f(x))=g(wx)。这里的g(z)函数就是Logistic函数也叫Sigmoid函数,其函数形式如下:
Sigmoid函数又个很漂亮的S形,如下图所示:
其函数的实际描述如下:
给定n个特征x={x1,x2,x3,…,xn},设条件概率p(y=1|x)位观测样本y相对于事件因素x发生的概率,用Sigmoid函数表示如下:
其中,
那么在x条件下y不发生的概率为:
假设现在又m个独立的观测事件:y=(y1 ,y2,…,ym ),则一个事件yi发生的概率为:
因此,m个独立事件出现的似然函数为(因为每个样本都是独立的,所以m个样本出现的概率就是它们各自出现的概率乘积):
然后,我们的目标就是求出使这一似然函数的值最大的参数估计,最大似然估计就是求出参数θ,使得L(θ)取得最大值,对函数L(θ)取对数得到:
最大似然估计就是求使得L(θ)取得最大值时的θ,其实这里可以使用梯度上升法求解,求得的θ就是要求最佳参数。我们可以乘以一个负的系数-1/m,转成梯度下降法求解,L(θ)转成J(θ):
所以,取J(θ)最小值时的θ为要求的最佳参数。
3.梯度下降求解
类似于上一篇写的线性回归的求解过程,这里也选用梯度下降法求解θ。
θ的更新过程如下:
因此,θ的更新可以写成:
接下来的过程可以参考线性回归的知识点,使用批量梯度下降或者随机梯度下降,代码的实现过程,只需要在线性回归的基础上加一个sigmoid函数,稍作修改即可。
4.正则化
对于线性回归或逻辑回归的损失函数构成的模型,可能有些权重很大,有些权重很小,而导致过拟合(通俗的讲就是过分拟合了训练数据),这使得模型的复杂度提高,泛华能力较差(可以理解为模型的一般性,对未知数据的预测能力)。
如下图从左往右分别是欠拟合,合适的拟合,过拟合。
过拟合问题的主要原因是拟合了过多的特征。其解决的主要方法如下:
(1)减少特征数量:
可用人工选择需要保留的特征,或者采用模型选择算法选取特征,例如spark里面的特征选择方法,基于卡方检验的特征选择)。减少特征会失去一些信息,即使特征选的很好。
(2)正则化:
保留所有的特征,但是减少θ的大小。正则化是结构化风险最小话策略的实现,是在经验风险上加一个正则化项或者惩罚项。正则化项一般是模型复杂度的单调递增函数,模型越复杂,正则化项就越大。
为了增强模型的泛化能力,防止训练模型过拟合,特别是对于大量稀疏特征的模型,模型的复杂度比较高,需要进行降维处理,我们需要保证在训练误差最小化的基础上,通过加上正则化项减少模型的复杂度。在逻辑回归中,支持L1和L2正则化。
损失函数如下:
5.java实现
package xudong.Regression;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
public class LogRegression {
/**
* 逻辑回归的实现
*
* 这里的逻辑回归的实现并未加入正则化项
* @author xudong
* @since 2017-8-8
*/
private double [][] trainData;//数据集矩阵
private int row;
private int column;
private double [] theta;//参数
private double alpha;//学习速率
private int iteration;//迭代次数
public LogRegression(String filename){
int rowFile=getRowNumber(filename);//获取输入训练数据的行数
int columnFile=getColumnNumber(filename);//获取训练数据的列数
trainData=new double[rowFile][columnFile+1];//加了一个特征x0 x0==1
this.row=rowFile;
this.column=columnFile;
this.alpha=0.001;
this.iteration=1000;
this.theta=new double[column-1];
initialize_theta();
loadTrainDataFromFile(filename,rowFile,columnFile);
}
public LogRegression(String filename,double alpha,int iteration){
int rowFile=getRowNumber(filename);//获取输入训练数据的行数
int columnFile=getColumnNumber(filename);//获取训练数据的列数
trainData=new double[rowFile][columnFile+1];//加了一个特征x0 x0==1
this.row=rowFile;
this.column=columnFile;
this.alpha=alpha;
this.iteration=iteration;
this.theta=new double[column-1];
initialize_theta();
loadTrainDataFromFile(filename,rowFile,columnFile);
}
/**
* 从文件中加载数据集到trainData中
* @param filename
* @param rowFile
* @param columnFile
*/
private void loadTrainDataFromFile(String filename, int rowFile,int columnFile) {
for(int i=0;i<row;i++){//trainData第一例全是0
trainData[i][0]=1.0;
}
File file=new File(filename);
BufferedReader br=null;
try {
br=new BufferedReader(new FileReader(file));
String temp=null;
int counter=0;
while((counter < row)&&(temp=br.readLine())!=null){
String[] tempData=temp.split(" ");
for(int i=0;i<column;i++){
trainData[counter][i+1]=Double.parseDouble(tempData[i]);
}
counter++;
}
br.close();
} catch (FileNotFoundException e) {
e.printStackTrace();
} catch (IOException e) {
e.printStackTrace();
}
}
/**
* 初始化参数theta
*/
private void initialize_theta() {
for(int i=0;i<theta.length;i++){
theta[i]=1.0;
}
}
/**
* 获取数据集的行数
* @param filename
* @return
*/
private int getRowNumber(String filename){
int count=0;
File file=new File(filename);
BufferedReader br=null;
try {
br=new BufferedReader(new FileReader(file));
while(br.readLine()!=null){
count++;
}
br.close();
} catch (FileNotFoundException e) {
e.printStackTrace();
} catch (IOException e) {
e.printStackTrace();
} finally {
if(br!=null){
try {
br.close();
} catch (IOException e) {
e.printStackTrace();
}
}
}
return count;
}
/**
* 获取数据集的列数(特征维度)
* @param filename
* @return
*/
private int getColumnNumber(String filename){
int count=0;
File file=new File(filename);
BufferedReader br=null;
try {
br=new BufferedReader(new FileReader(file));
String temp=br.readLine();
if(temp!=null){
count=temp.split(" ").length;
}
br.close();
} catch (FileNotFoundException e) {
e.printStackTrace();
} catch (IOException e) {
e.printStackTrace();
} finally {
if(br!=null){
try {
br.close();
} catch (IOException e) {
e.printStackTrace();
}
}
}
return count;
}
/**
* 训练模型并计算theta
*/
public void trainTheta(){
int iteration=this.iteration;
while(iteration-- > 0){
//对每一个thetai 求偏导
double[] partial_derivative=compute_partial_derivative();
for(int i=0;i<theta.length;i++){
theta[i]-=alpha*partial_derivative[i];
}
}
}
private double[] compute_partial_derivative() {
double[] partial_derivative=new double[theta.length];
for(int j=0;j<theta.length;j++){
partial_derivative[j]= compute_partial_derivative_for_theta(j);
}
return partial_derivative;
}
private double compute_partial_derivative_for_theta(int j) {
double sum=0.0;
for(int i=0;i<row;i++){
sum+=h_theta_x_i_minus_y_i_times_x_j_i(i,j);
}
return sum/row;
}
private double h_theta_x_i_minus_y_i_times_x_j_i(int i, int j) {
double[] oneRow=getRow(i);//取一行数据,前面是feature最后一个是y
double result=0.0;
for(int k=0;k<(oneRow.length-1);k++){
result+=theta[k]*oneRow[k];
}
//这个地方加入了sigmoid函数
result=sigmoid(result);
//
result-=oneRow[oneRow.length-1];
result*=oneRow[j];
return result;
}
private double[] getRow(int i) {
return trainData[i];
}
/**
* sigmoid函数
* @param x
*/
private double sigmoid(double x){
return (double)(1.0/(1+Math.exp(-x)));
}
//这里是主方法
public static void main(String[] args) {
String filename="";
LogRegression log=new LogRegression(filename);
log.trainTheta();
//预测的函数可以在这里写,就是一个根据训练完的参数theta求解的过程
//过程省略,请自行补上,也可以打印出theta来查看
}
}