Spark MLlib

MLlib is Spark's machine learning library.


MLlib provides two packages: spark.mllib is the original API built on top of RDDs; spark.ml is a higher-level API built on top of DataFrames, intended mainly for machine learning pipeline operations.


spark.ml is the recommended package, because DataFrames are more versatile and flexible to work with. Spark will nevertheless keep supporting spark.mllib, and spark.mllib will continue to gain new features and algorithms. Developers who are familiar with ML pipelines are encouraged to build new algorithms on spark.ml. (The rough idea: both packages matter. spark.mllib keeps being developed at the lower level, while the application layer should use spark.ml.)


MLlib's linear algebra package is Breeze, which in turn depends on netlib-java.
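
As a quick sanity check of which backend netlib-java actually loaded, you can print the implementation class. This is a minimal sketch (the class name BlasCheck is made up for illustration); with no native BLAS installed, netlib-java falls back to the pure-Java F2jBLAS implementation.

package my.demo;

import com.github.fommil.netlib.BLAS;

public class BlasCheck {

	public static void main(String[] args) {
		// Print the BLAS implementation that netlib-java resolved at runtime,
		// e.g. com.github.fommil.netlib.F2jBLAS when no native library is found.
		System.out.println(BLAS.getInstance().getClass().getName());
	}
}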


About the spark.ml package

    The main concepts involved in a Pipeline:

        DataFrame: holds the data and supports a wide range of operations on it.

        Transformer: transforms one DataFrame into another DataFrame.

        Estimator: an algorithm that processes a DataFrame and produces a Transformer; a learning algorithm is an Estimator.

        Pipeline: chains multiple Transformers and Estimators into an ordered workflow that carries out a complete machine learning process.

        Parameter: all Transformers and Estimators share a common API for specifying parameters.


    How it works

        A Pipeline is a sequence of stages, and each stage is either a Transformer or an Estimator. The stages are ordered: a DataFrame enters as input and flows through them one by one to produce the final result. A Transformer stage calls its transform() method on the DataFrame; an Estimator stage calls its fit() method on it.

        Broadly speaking, a Pipeline is itself an Estimator: it has a fit() method.


    First example:

package my.demo;

import java.util.Arrays;
import java.util.List;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;

public class SparkML1 {

	public static void main(String[] args) {
		SparkConf conf = new SparkConf().setAppName("SparkMLDemo");
		conf.setMaster("local[2]");
		JavaSparkContext jsc = new JavaSparkContext(conf);
		SQLContext sqlContext = new org.apache.spark.sql.SQLContext(jsc);

		DataFrame training = sqlContext.createDataFrame(Arrays.asList(
				new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)),
				new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
				new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)),
				new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5))),
				LabeledPoint.class);

		// Create a LogisticRegression instance. This instance is an Estimator.
		LogisticRegression lr = new LogisticRegression();
		// Print out the parameters, documentation, and any default values.
		System.out.println("LogisticRegression parameters:\n"
				+ lr.explainParams() + "\n");

		// We may set parameters using setter methods.
		lr.setMaxIter(10).setRegParam(0.01);

		// Learn a LogisticRegression model. This uses the parameters stored in lr.
		LogisticRegressionModel model1 = lr.fit(training);
		// Since model1 is a Model (i.e., a Transformer produced by an Estimator),
		// we can view the parameters it used during fit(). This prints the
		// parameter (name: value) pairs, where names are unique IDs for this
		// LogisticRegression instance.
		System.out.println("Model 1 was fit using parameters: "
				+ model1.parent().extractParamMap());

		// We may alternatively specify parameters using a ParamMap.
		ParamMap paramMap = new ParamMap()
				.put(lr.maxIter().w(20)) // Specify one Param.
				.put(lr.maxIter(), 30) // This overwrites the original maxIter.
				.put(lr.regParam().w(0.1), lr.threshold().w(0.55)); // Specify multiple Params.

		// One can also combine ParamMaps.
		ParamMap paramMap2 = new ParamMap()
				.put(lr.probabilityCol().w("myProbability")); // Change the output column name.
		ParamMap paramMapCombined = paramMap.$plus$plus(paramMap2);

		// Now learn a new model using the paramMapCombined parameters.
		// paramMapCombined overrides all parameters set earlier via lr.set* methods.
		LogisticRegressionModel model2 = lr.fit(training, paramMapCombined);
		System.out.println("Model 2 was fit using parameters: "
				+ model2.parent().extractParamMap());

		// Prepare test documents.
		DataFrame test = sqlContext.createDataFrame(Arrays.asList(
				new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
				new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
				new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))),
				LabeledPoint.class);

		// Make predictions on test documents using the Transformer.transform() method.
		// LogisticRegression.transform will only use the 'features' column.
		// Note that model2.transform() outputs a 'myProbability' column instead of
		// the usual 'probability' column, since we renamed the lr.probabilityCol
		// parameter previously.
		DataFrame results = model2.transform(test);
		for (Row r : results.select("features", "label", "myProbability",
				"prediction").collect()) {
			System.out.println("(" + r.get(0) + ", " + r.get(1) + ") -> prob="
					+ r.get(2) + ", prediction=" + r.get(3));
		}

		jsc.stop();
	}

}


    Second example: using a Pipeline

package my.demo;

import java.io.Serializable;
import java.util.Arrays;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.feature.HashingTF;
import org.apache.spark.ml.feature.Tokenizer;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
public class SparkML2 {

	static public class Document implements Serializable {
		private long id;
		private String text;

		public Document(long id, String text) {
			this.id = id;
			this.text = text;
		}

		public long getId() { return this.id; }
		public void setId(long id) { this.id = id; }

		public String getText() { return this.text; }
		public void setText(String text) { this.text = text; }
	}

	static public class LabeledDocument extends Document implements Serializable {
		private double label;

		public LabeledDocument(long id, String text, double label) {
			super(id, text);
			this.label = label;
		}

		public double getLabel() { return this.label; }
		public void setLabel(double label) { this.label = label; }
	}

	public static void main(String[] args) {
		SparkConf conf = new SparkConf().setAppName("SparkMLDemo");
		conf.setMaster("local[2]");
		JavaSparkContext jsc = new JavaSparkContext(conf);
		SQLContext sqlContext = new org.apache.spark.sql.SQLContext(jsc);

		// Prepare training documents, which are labeled.
		DataFrame training = sqlContext.createDataFrame(Arrays.asList(
				new LabeledDocument(0L, "a b c d e spark", 1.0),
				new LabeledDocument(1L, "b d", 0.0),
				new LabeledDocument(2L, "spark f g h", 1.0),
				new LabeledDocument(3L, "hadoop mapreduce", 0.0)
		), LabeledDocument.class);

		// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
		Tokenizer tokenizer = new Tokenizer()
				.setInputCol("text")
				.setOutputCol("words");
		HashingTF hashingTF = new HashingTF()
				.setNumFeatures(1000)
				.setInputCol(tokenizer.getOutputCol())
				.setOutputCol("features");
		LogisticRegression lr = new LogisticRegression()
				.setMaxIter(10)
				.setRegParam(0.01);
		Pipeline pipeline = new Pipeline()
				.setStages(new PipelineStage[] {tokenizer, hashingTF, lr});

		// Fit the pipeline to the training documents.
		PipelineModel model = pipeline.fit(training);

		// Prepare test documents, which are unlabeled.
		DataFrame test = sqlContext.createDataFrame(Arrays.asList(
				new Document(4L, "spark i j k"),
				new Document(5L, "l m n"),
				new Document(6L, "mapreduce spark"),
				new Document(7L, "apache hadoop")
		), Document.class);

		// Make predictions on the test documents.
		DataFrame predictions = model.transform(test);
		for (Row r : predictions.select("id", "text", "probability", "prediction").collect()) {
			System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2)
					+ ", prediction=" + r.get(3));
		}

		jsc.stop();
	}

}


About the spark.mllib package

MLlib supports local vectors and local matrices, which are stored on a single cluster node, rather than being split into partitions across different nodes the way an RDD is.

For classification and regression problems, a training sample is represented in spark.mllib as a "labeled point", i.e., a data point with a label attached.


Local vectors come in two kinds: dense and sparse. The storage scheme is the same as MATLAB's.
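
The following is a minimal sketch of these local data types (Spark 1.x spark.mllib API; the class name LocalDataTypes and the sample values are made up for illustration):

package my.demo;

import org.apache.spark.mllib.linalg.Matrices;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;

public class LocalDataTypes {

	public static void main(String[] args) {
		// Dense vector: every component is stored explicitly.
		Vector dense = Vectors.dense(1.0, 0.0, 3.0);

		// Sparse vector: size, the indices of the non-zero entries, and their values.
		Vector sparse = Vectors.sparse(3, new int[] {0, 2}, new double[] {1.0, 3.0});

		// A labeled point pairs a double label with a feature vector.
		LabeledPoint positive = new LabeledPoint(1.0, dense);
		LabeledPoint negative = new LabeledPoint(0.0, sparse);

		// Local 3x2 dense matrix, stored in column-major order:
		// ((1.0, 2.0), (3.0, 4.0), (5.0, 6.0))
		Matrix matrix = Matrices.dense(3, 2, new double[] {1.0, 3.0, 5.0, 2.0, 4.0, 6.0});

		System.out.println(positive);
		System.out.println(negative);
		System.out.println(matrix);
	}
}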





