StreamingLogisticRegression计算部分源码解读

Streaming Logistic Regression计算部分源码解读

大家好,我是一拳就能打爆A柱的猛男

最近重新调整了时间,以后源码部分和翻译同时做,可能进度慢一点,但是尽量两天一更才行。今天给大家带来流式逻辑回归(Streaming Logistic Regression)计算部分的源码解读,下面我将从下面几个部分来讲解:逻辑回归,程序入口,流式逻辑回归源码。

1、逻辑回归

​ 机器学习解决问题的时候,需要将问题进行归类,大致分为两种:数值预测和分类。针对数值预测问题一般采用回归模型;而针对分类问题可以选择的方法就比较多了,包括:SVM、KNN、朴素贝叶斯等等。而逻辑回归也是针对分类问题设计的,虽然逻辑回归有回归二字,但其解决的确实是分类问题(利用激活函数)。

​ 从名字可以看出来,逻辑回归的本质还是回归,所以其基本思路还是回归的思路(权重向量与样本向量的操作),但是它通过激活函数就可以做到分类。简单来说,逻辑回归 = 线性回归 + 激活函数。下面我将介绍三个常见的激活函数:Sigmoid、tanh、relu。

1.1 Sigmoid激活函数

Sigmoid函数公式如下:
y = 1 1 + e − z y = \frac{1}{1+e^{-z}} y=1+ez1
其函数如图1,Sigmoid函数以(0,0.5)作为中心,将z值(回归值)作为自变量,z越接近0则y越接近0.5,反之z值越远离0则y越接近0或1,以y=0.5为分割线可以做分类。

图1 Sigmoid函数

1.2 Tanh激活函数

Tanh函数公式如下:
y = e x − e − x e x + e − x y = \frac{e^x - e^{-x}}{e^x + e^{-x}} y=ex+exexex
如图2所示,Tanh函数与Sigmoid函数类似,x值(回归值)越接近0则y值越接近0,反之x值越远离,则y越接近1或-1。以y=0为界做分类。

在这里插入图片描述

图2 Tanh函数

1.3 Relu激活函数

Relu函数公式如下:
y = m a x ( 0 , x ) y = max(0,x) y=max(0,x)
如图3所示,Relu函数与前两种有很大区别,其梯度不是0就是1,十分易于计算。以y=0为界分类。

在这里插入图片描述

图3 Relu函数

2、程序入口

​ 由Spark官方提供的案例可以得知模型的具体用法,同时也可以在trainOn中找到端倪,与《StreamingLinearRegressionWithSGD源码分析 流式线性回归源码分析》的流程相似,通过trainOn方法中的lgorithm.run方法进入GeneralizedLinearAlgorithm。看过之前流式线性回归解读的朋友应该记得,该文件是通用线性回归算法,会想到刚刚对于逻辑回归的定义:逻辑回归 = 线性回归 + 激活函数,所以在训练阶段逻辑回归与线性回归使用通用线性算法且步骤相同是理所当然的。

​ 故逻辑回归不应该关注其训练部分(在分析线性回归时了解过),而应该关注其预测部分(predictOn)。

3、流式逻辑回归源码

根据流式逻辑回归模型的继承树可以知道,其实际使用的是批处理下的逻辑回归模型LogisticRegressionModel,而该模型继承了可分类模型ClassificationModel特质:

class LogisticRegressionModel @Since("1.3.0") (
    @Since("1.0.0") override val weights: Vector,
    @Since("1.0.0") override val intercept: Double,
    @Since("1.3.0") val numFeatures: Int,
    @Since("1.3.0") val numClasses: Int)
extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable

所以只需要在predictOn找到预测方法,最后来到LogisticRegressionModel找到对应的方法即可:

/**
   * Predict the result given a data point and the weights learned.
   *
   * @param dataMatrix Row vector containing the features for this data point
   * @param weightMatrix Column vector containing the weights of the model
   * @param intercept Intercept of the model.
   */
protected def predictPoint(dataMatrix: Vector, weightMatrix: Vector, intercept: Double): Double

在LogisticRegressionModel中,这次我决定在源码中打注释,这样大家就不需要来回切换对照源码了:

override protected def predictPoint(
    dataMatrix: Vector, // 样本向量
    weightMatrix: Vector, // 模型训练出的权重向量
    intercept: Double) = { // 偏置
    // 验证样本向量的size是否等于模型的特征数,不满足则抛出异常
    require(dataMatrix.size == numFeatures)
    // 对于二分类做预测
    if (numClasses == 2) {
        // 求出回归值margin 即z
        val margin = dot(weightMatrix, dataMatrix) + intercept
        // 根据Sigmoid函数计算y,Sigmoid函数在1.1节介绍
        val score = 1.0 / (1.0 + math.exp(-margin))
        /* private var threshold: Option[Double] = Some(0.5)
         * threshold作为阀门函数,默认阀门t = 0.5
         * 对默认阀门t = 0.5做判断,若score>t则为1否则为0
         */
        threshold match {
            case Some(t) => if (score > t) 1.0 else 0.0
            case None => score
        }
    } else { // 对于多分类做预测
        
        var bestClass = 0 // 用于标记相似度最大的分类
        var maxMargin = 0.0 // 用于记录最大回归值
        val withBias = dataMatrix.size + 1 == dataWithBiasSize // 是否带偏置
        (0 until numClasses - 1).foreach { i => // 遍历每个分类
            var margin = 0.0 // 临时变量回归值
            // 接下来计算该样本对于每一类别的回归值,并且维护最大回归值及其对应类别
            dataMatrix.foreachNonZero { (index, value) =>
                margin += value * weightsArray((i * dataWithBiasSize) + index)
            }
            // Intercept is required to be added into margin.
            if (withBias) {
                margin += weightsArray((i * dataWithBiasSize) + dataMatrix.size)
            }
            if (margin > maxMargin) {
                maxMargin = margin
                bestClass = i + 1
            }
        }
        bestClass.toDouble
    }
}

​ 综上,对于二分类问题,Spark中的逻辑回归选择使用Sigmoid作为激活函数,对于多分类问题,逻辑回归选择维护最大回归值及其对应的类别。这就是整个流式逻辑回归的激活函数部分的内容。所以,如果能看清楚Spark的通用线性回归模型及算法这两个类的代码,那么逻辑回归只需要去了解其激活函数部分即可。

总结

​ 本来我以为逻辑回归部分又是一套庞大复杂的继承关系,所以我选择先去写了几篇论文的翻译。现在再来看流式逻辑回归的源码,其实并不需要过多的准备,当然前提是要弄清楚线性回归部分的训练代码。经过翻阅几个流式算法的源码,我对整个Spark Streaming的框架的理解又多了一点,Streaming是一个秒级的伪流式大数据计算框架,其内部基于DStream对指定间隔时间内的数据做批处理。按照我的理解,Streaming的框架设计的十分巧妙,基于RDD的DStream负责伪流式的批处理操作,离线批处理机器学习算法库MLlib负责对数据做训练,而从流式到批式之间的桥梁就是DStream。训练的模型通过shuffle机制更新到主机中,而主机又将模型发送至各个从节点进行下一批的训练,如此往复形成循环。

RDD的DStream负责伪流式的批处理操作,离线批处理机器学习算法库MLlib负责对数据做训练,而从流式到批式之间的桥梁就是DStream。训练的模型通过shuffle机制更新到主机中,而主机又将模型发送至各个从节点进行下一批的训练,如此往复形成循环。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值