因为自身原因最近再学习spark MLlib,看的教材是《spark机器学习》,感觉这本书偏入门并且有很多实操,非常适合新手。下面就是我在学习到第五章关于分类算法的一些要点,最要是通过代码实操,具体算法原理就不介绍。
一、数据来源及开发环境
开发环境:为了方便代码管理这里使用了IDEA集成开发环境,单机进行代码调试感觉很方便嘛,主要环境与我前两篇博客中部署的环境一致。
数据源:机器学习实在中数据的获取很重要,互联网上要找到类似数据非常容易。本实例使用的是Kaggle竞赛数据(相信学习机器学习的都知道这个比赛)。数据是关于网站点击数据,主要用于推荐的页面是短暂流行还是长久流行。下载地址,下载train.tsv的文件,需要注册才能下载。
二、数据预处理
大家下载好数据以后可以通过相应的工具打开看看数据构成。由于数据中第一行为列名,在算法中是用不到的,因此将其删除并存为train_noheader.tsv,linux命令如下:
sed 1d train.tsv >train_noheader.tsv
三、代码解析
使用IDEA新建一个scala class,键入如下代码:
//导入各种类
import org.apache.spark.mllib.classification.LogisticRegressionWithSGD
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.{SparkContext, SparkConf}
/**
* Created by luo on 12/12/15.
*/
object ML_Classification {
def main(args:Array[String]){
//代码初始化的一些步骤
val conf=new SparkConf().setAppName("classification").setMaster("local[2]")
val sc=new SparkContext(conf)
val rawData=sc.textFile("/home/luo/sparkLearning/MLData/train_noheader.tsv")
val records=rawData.map(_.split("\t"))//数据是以\t分割
val data=records.map{point=>
//将数据中的引号全部替换为空
val replaceData=point.map(_.replaceAll("\"",""))
//本数据的头四个字段不会用到,数据的一个字段代表分类的结果,1为长久,0为短暂
val label=replaceData(replaceData.size-1).toInt
val features=replaceData.slice(4,replaceData.size-1).map(x=>if(x=="?") 0.0 else x.