spark mllib源码分析之逻辑回归弹性网络ElasticNet(二)

相关文章
spark mllib源码分析之逻辑回归弹性网络ElasticNet(一)
spark源码分析之L-BFGS
spark mllib源码分析之OWLQN
spark中的online均值/方差统计
spark源码分析之二分类逻辑回归evaluation
spark正则化

2. 训练

2.1. 训练参数设置

设置用于控制训练的参数

  • setMaxIter,最大迭代次数,训练的截止条件,默认100次
  • setFamily,binomial(二分类)/multinomial(多分类)/auto,默认为auto。设为auto时,会根据schema或者样本中实际的class情况设置是二分类还是多分类,最好明确设置
  • setElasticNetParam,弹性参数,用于调节L1和L2之间的比例,两种正则化比例加起来是1,详见后面正则化的设置,默认为0,只使用L2正则化,设置为1就是只用L1正则化
  • setRegParam,正则化系数,默认为0,不使用正则化
  • setTol,训练的截止条件,两次迭代之间的改善小于tol训练将截止
  • setFitIntercept,是否拟合截距,默认为true
  • setStandardization,是否使用归一化,这里归一化只针对各维特征的方差进行
  • setThresholds/setThreshold,设置多分类/二分类的判决阈值,多分类是一个数组,二分类是double值
  • setAggregationDepth,设置分布式统计时的层数,主要用在treeAggregate中,数据量越大,可适当加大这个值,默认为2
  • set*Col,因为训练时使用的样本先被搞成了DataFrame结构,这些是设置列名,方便训练时选取label,weight,feature等列,spark实现了默认libsvm格式的读取,如果需要,可以自己读取样本文件,转成DataFrame格式,并设置这些列名

2.2.数据准备

2.2.1. 样本处理

将封装成DataFrame的输入数据再转成简单结构的instance,包括label,weight,特征,默认每个样本的weight为1

val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
val instances: RDD[Instance] =
  dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
    case Row(label: Double, weight: Double, features: Vector) =>
      Instance(label, weight, features)
  }

2.2.2. 统计

统计样本每个特征的方差,均值,label的分布情况,用到了MultivariateOnlineSummarizer和MultiClassSummarizer,前面有介绍

val (summarizer, labelSummarizer) = { 
  val seqOp = (c: (MultivariateOnlineSummarizer, MultiClassSummarizer),
    instance: Instance) =>
      (c._1.add(instance.features, instance.weight), c._2.add(instance.label, instance.weight))

  val combOp = (c1: (MultivariateOnlineSummarizer, MultiClassSummarizer),
    c2: (MultivariateOnlineSummarizer, MultiClassSummarizer)) =>
      (c1._1.merge(c2._1), c1._2.merge(c2._2))

  instances.treeAggregate(
    new MultivariateOnlineSummarizer, new MultiClassSummarizer
  )(seqOp, combOp, $(aggregationDepth))
}
//各维特征的weightSum
val histogram = labelSummarizer.histogram
//label非法,主要是label非整数和小于0的情况
val numInvalid = labelSummarizer.countInvalid
val numFeatures = summarizer.mean.size
//如果有截距,相当于增加一维值全为1的特征
val numFeaturesPlusIntercept = if (getFitIntercept) numFeatures + 1 else numFeatures

2.2.3. 其他参数推断

  • class数,如果在schema中指定,就要求其大于等于统计得到的class数,否则取统计得到的class数
val numClasses = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
  case Some(n: Int) =>
    require(n >= histogram.length, s"Specified number of classes $n was " + 
      s"less than the number of unique labels ${histogram.length}.")
    n   
  //最好是labelSummarizer.numClasses
  case None => histogram.length
}
  • 是否多分类,根据family参数确定。如果是二分类,要求numClasses为1或2,返回false;如果是多分类,则返回true;如果是默认的auto,则根据numClasses是否大于2得到
val isMultinomial = $(family) match {
  case "binomial" =>
    require(numClasses == 1 || numClasses == 2, s"Binomial family only supports 1 or 2 " + 
    s"outcome classes but found $numClasses.")
    false
  case "multinomial" => true
  case "auto" => numClasses > 2 
  case other => throw new IllegalArgumentException(s"Unsupported family: $other")
}
  • 系数矩阵包含的系数集个数,我们之前的文章说过,多分类的系数矩阵在实际模型训练时(breeze库)是将多个weight向量是拼在一起的,相当于 [[w11,w12,...],[w21,w22,...],...] ,因此个数与class数相同,这里的系数矩阵是使用矩阵存储的,这个值作为矩阵的行;而二分类因为只有一个weight向量,因此为1
val numCoefficientSets = if (isMultinomial) numClasses else 1
  • 阈值,因为这里兼容了二分类与多分类,thresholds实际是个Array,length应该等于class的个数;如果是二分类,阈值就只有一个,设置在threshold,Double类型
if (isDefined(thresholds)) {
  require($(thresholds).length == numClasses, this.getClass.getSimpleName +
    ".train() called with non-matching numClasses and thresholds.length." +
    s" numClasses=$numClasses, but thresholds has length ${
    $(thresholds).length}")
}

2.3. 确定待训练模型

根据二/多分类,是否拟合截距,L1/L2等确定训练使用的优化方法,损失函数

2.3.1. 拟合截距 && label唯一

判断条件为

$(fitIntercept) && isConstantLabel

label是否唯一的判断

val isConstantLabel = histogram.count(_ != 0.0) == 1

histogram是Array,里面放着样本中各label的数量,也就是说样本里只有一种label。
这种情况返回的系数矩阵为全0的SparseMatrix,对于截距,如果是多分类,返回稀疏向量,向量长度为numClasses,只有index为label的位置有值Double.PositiveInfinity;如果是二分类,返回dense vector,值为Double.PositiveInfinity

2.3.2. 不拟合截距 && label唯一

判断条件为

!$(fitIntercept) && isConstantLabel

此种情况下,算法可能不会收敛,给出了警告信息,但是会继续尝试优化。

2.3.3. 不拟合截距 && 各维特征的特征值相同却非0

判断条件

!$(fitIntercept) && (0 unti
  • 2
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值