机器学习-逻辑回归-详细示例版

2 篇文章 0 订阅
2 篇文章 0 订阅

在分类问题中,你要预测的变量y是离散的值,适合使用逻辑回归 (Logistic Regression) 算法。

 

通常从最简单的二元分类问题开始解决,即“是不是”的问题。

h_\theta(x)\geq0.5 时,预测 y=1;

h_\theta(x)<0.5 时,预测 y=0;

由此,MINST手写数字识别问题可以被分解为:

是数字0的概率是多少;

是数字1的概率是多少;

...

是数字9的概率是多少;

 

但是,线性回归模型,因为其预测的值可以超越[0,1]的范围,并不适合解决这样的问题。

 

因此引入一个新的模型,逻辑回归,该模型的输出变量范围始终在0和1之间。

逻辑回归模型的假设是: h_\theta(x) = g(\theta^TX)

其中,X代表特征向量,g代表逻辑函数。

常用的逻辑函数为S形函数(Sigmoid function),公式为:

g(z) = \frac{1}{1+e^{-z}}

sigmoid函数

 

再看代价函数:

J(\theta) = \frac{1}{m} \sum_{i=1}^m Cost(h_\theta(x^{(i)}), y^{(i)})

其中

Cost(h_\theta(x^{(i)}), y^{(i)})=

{

if y = 1 : -log(h_\theta(x))

if y = 0 : -log(1-h_\theta(x))

}

因此上式可以合并写成:

Cost(h_\theta(x^{(i)}), y^{(i)})= -ylog(h_\theta(x))-(1-y)log(1-h_\theta(x))

 

y = log(x):(以e为底)

仅看x∈(0, 1]区间部分:

因此 y = -log(x) 在x∈(0, 1]区间时:

因此 y = -log(1-x) 在x∈[0, 1)区间时:

 

 

为什么 Cost(h_\theta(x^{(i)}), y^{(i)}) 在逻辑回归中不能再使用 (h_\theta(x^{(i)})-y^{(i)})^2 ?

因为在线性回归中, h_\theta(x)=\theta^Tx ,则对应的代价函数是凸函数,但是在逻辑回归中, h_\theta(x)= \frac{1}{1+e^{-\theta^Tx}} ,则对应的代价函数是非凸函数。

 

如果代价函数是 (h_\theta(x^{(i)})-y^{(i)})^2

非凸

如果代价函数是 -ylog(h_\theta(x))-(1-y)log(1-h_\theta(x))

 

逻辑回归 Java示例

演示了通过逻辑回归,判断数组是否呈现凸性(即中间数字大于两边数字)。

/** 数据集 */
public class TrainingSet {

    public List<Data> dataList = new ArrayList<>();

    public void add(double y, double... x) {
        Data data = new Data();
        data.x = x;
        data.y = y;
        dataList.add(data);
    }

    public static class Data {
        public double[] x;
        public double y;
    }

}
/** 逻辑回归例子 */
public class LogisticRegression {

    public static void main(String[] args) {
        LogisticRegression lr = new LogisticRegression();
        lr.study();
        lr.test();
    }

    /** 产生模拟数据
     * 该模拟数据符合如下凸特性:即中间数字大于两边数字
     * 例如:x1=0.3, x2=0.4, x3=0.1 为凸,因此 y = 1
     * 例如:x1=0.5, x2=0.4, x3=0.1 为非凸,因此 y = 0
     * */
    TrainingSet mock() {
        TrainingSet ts = new TrainingSet();
        final int size = 200;
        //生成凸数据
        for(int i = 0; i < size; i++) {
            double middle = Math.random();
            double x1 = middle * Math.random();
            double x3 = middle * Math.random();
            ts.add(1, x1, middle, x3);
        }
        //生成非凸数据
        for(int i = 0; i < size; i++) {
            double x1 = Math.random();
            double x3 = Math.random();
            double min = Math.min(x1, x3);
            double middle = min * Math.random();
            ts.add(0, x1, middle, x3);
        }
        return ts;
    }

    double[] θ = new double[]{0,0,0};

    final double alpha = 0.001;

    /** 学习 */
    void study() {
        TrainingSet ts = mock();
        double lastCost = Double.MAX_VALUE;
        int limitTry = 1000;
        while(--limitTry > 0) {
            double cost = calculateCost(ts);
            if(cost > lastCost) {
                break;
            }
            lastCost = cost;

            double[] dθs = new double[θ.length];//要同步更新
            for(int i = 0; i < θ.length; i++) {
                dθs[i] = calculateDelta(ts, i);
            }
            for(int i = 0; i < θ.length; i++) {
                θ[i] -= alpha * dθs[i];
            }
        }
        //System.out.println("limitTry剩余" + limitTry);
        for(int i = 0; i < θ.length; i++) {
            System.out.print("θ"+i+"=" + θ[i] + "  ");
        }
        System.out.println();
    }

    /** 验证测试 */
    void test() {
        TrainingSet ts = mock();
        int testNum = ts.dataList.size();
        int correctNum = 0;
        int errorNum = 0;
        for(TrainingSet.Data data : ts.dataList) {
            //System.out.println("预测值:" + hypothesis(data) + ",实际值:" + data.y);
            double hy = hypothesis(data) >= 0.5 ? 1 : 0;
            if(hy == data.y) {
                correctNum++;
            } else {
                errorNum++;
            }
        }
        System.out.println("测试数量=" + testNum + ", 正确=" + correctNum + ", 错误=" + errorNum + ", 准确率=" + 100*((double)correctNum/testNum) + "%'");
    }

    /** 计算代价 */
    double calculateCost(TrainingSet ts) {
        double sum = 0;
        for(TrainingSet.Data data : ts.dataList) {
//            if(data.y == 1) {
//                sum += -Math.log(hypothesis(data));
//            } else if(data.y == 0) {
//                sum += -Math.log(1-hypothesis(data));
//            }
            sum += (-1 * data.y * (Math.log(hypothesis(data)))) - (1-data.y) * Math.log(1-hypothesis(data));//等同于上面注释部分代码
        }
        sum = sum / ts.dataList.size();
        return sum;
    }

    /** 预测函数 */
    double hypothesis(TrainingSet.Data data) {
        double θTX = 0;
        for(int i = 0; i < data.x.length; i++) {
            θTX += θ[i] * data.x[i];
        }
        double val = 1 / (1 + Math.pow(Math.E, ( -1 * θTX )));
        return val;
    }

    /** 计算代价函数的偏导数
     * @param i 对θi求偏导 */
    double calculateDelta(TrainingSet ts, int i) {
        double sum = 0;
        for(TrainingSet.Data data : ts.dataList) {
            sum += (hypothesis(data) - data.y) * data.x[i];
        }
        sum = sum / ts.dataList.size();
        return sum;
    }

}

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值