测试数据:链接: https://pan.baidu.com/s/1i7owaXJ 密码: 4wqg
代码记录:
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.{LabeledPoint, LinearRegressionWithSGD}
import org.apache.spark.sql.SparkSession

/**
 * Trains a linear regression model with stochastic gradient descent on a
 * text data set and prints the learned weights.
 *
 * Each non-empty input line is parsed as `label,feature1 feature2 ...`:
 * the text before the first comma is the label, and the second
 * comma-separated field holds the whitespace-separated feature values.
 *
 * Created by root on 1/12/18.
 */
object LinearRegressionWithSGDTest {

  /**
   * Program entry point: reads the data file, trains the model and prints its weights.
   *
   * @param args command-line arguments (not used)
   */
  def main(args: Array[String]): Unit = {
    val path = "/home/enche/data/lpsa.data"
    val spark = SparkSession.builder()
      .appName("LinearRegressionWithSGD")
      .master("local")
      .getOrCreate()

    try {
      val sc = spark.sparkContext
      val data = sc.textFile(path)

      // Convert the raw text into RDD[LabeledPoint].
      val traindata = data
        .map(_.trim)
        .filter(_.nonEmpty) // skip blank lines instead of failing on "".toDouble
        .map { line =>
          // Split once and reuse the fields for both the label and the features.
          val fields = line.split(",")
          val features = fields(1).trim.split("\\s+").map(_.toDouble)
          LabeledPoint(fields(0).toDouble, Vectors.dense(features))
        }
        .cache() // SGD makes repeated passes over the data; avoid re-reading the file each time

      // train(RDD[LabeledPoint], max iterations, step size (learning rate),
      //       fraction of the data used in each mini-batch)
      val model = LinearRegressionWithSGD.train(traindata, 100, 0.2, 0.1)
      print(model.weights)
    } finally {
      // Always release the Spark session, even when parsing or training fails.
      spark.stop()
    }
  }
}