Spark MLlib SVM 文本分类器实现

好久没写博客了,最近搞了一个文本分类器,在此记录一下:


简介:

支持向量机,因其英文名为 support vector machine,故一般简称 SVM,通俗来讲,它是一种二类分类模型,其基本模型定义为特征空间上的间隔最大的线性分类器,其学习策略便是间隔最大化,最终可转化为一个凸二次规划问题的求解。

1  “机” —— Classification Machine,分类器

2  “支持向量” —— 他们就是离分界线最近的向量。也就是说分界面是靠这些向量确定的,他们支撑着分类面。名字就是这么来的...(就是离最优分类平面最近的离散点,也可以称为向量) 


spark自带了一个svm实现的dome,该dome直接读取保存libsvm所需稀疏向量的文件,但是并未提供向量化方法,需自己调用HashingTF、IDF转换为稀疏向量

代码:

  1. /**
  2.  * SVM分类对象
  3.  * @author wangzengxu
  4.  */
  5. object SVM{

  6.     def main(args: Array[String]){
  7.           
  8.      val Array(
  9.          rightPath,      // 正面训练集路径
  10.          negativePath,   // 负面训练集路径
  11.          waitData,       // 待分类数据存放路径
  12.          vectorsLocl,    // 向量存放路径
  13.          iterativeNum    // 迭代次数
  14.          ) = args 
  15.      
  16.      var sparkconf = new SparkConf().setAppName("wzx_svm_classificationsV2")
  17.      var sc = new SparkContext(sparkconf)
  18.      
  19.      sc.addJar("/usr/wzx/spark/svm/SVM_WZX_lib/IKAnalyzer2012_u6.jar");
  20.      sc.addJar("/usr/wzx/spark/svm/SVM_WZX_lib/lucene-analyzers-common-4.3.0.jar");
  21.      sc.addJar("/usr/wzx/spark/svm/SVM_WZX_lib/lucene-core-4.3.0.jar");
  22.      sc.addJar("/usr/wzx/spark/svm/SVM_WZX_lib/lucene-queryparser-4.3.0.jar");
  23.      
  24.      val train_vectors_local = vectorsLocl+"/train-"+DateUtils.getNowDate()  //训练向量存放目录 
  25.      val wait_vectors_local = vectorsLocl+"/wait-"+DateUtils.getNowDate()    //待分向量存放目录 
  26.      
  27.      val data_path_right = rightPath           //正面训练集文章路径文件 每行一篇
  28.      val data_path_negative = negativePath     //负面训练集文章路径文件 每行一篇
  29.      val data_path_wait = waitData             //待分数据存放路径 
  30.      val iterative_number = iterativeNum.toInt //训练模型迭代次数
  31.      
  32.      /***********************start 分词******************************************/
  33.      
  34.        val right_data = sc.textFile(data_path_right)
  35.        
  36.        val negative_data = sc.textFile(data_path_negative)
  37.        
  38.        val wait_data = sc.textFile(data_path_wait)
  39.        
  40.        //去停用词 
  41.        
  42.        val right_text = right_data.map { x =>
  43.             val str = IKUtils.participle(x)
  44.             (1,str) //正面1
  45.         } 
  46.         
  47.         val negative_text = negative_data.map { x =>
  48.             val str = IKUtils.participle(x)
  49.             (0,str) //负面0
  50.         }
  51.         
  52.         val wait_text = wait_data.map { x =>
  53.             val str = IKUtils.participle(x)
  54.             (2,str) //待分2
  55.         }
  56.         
  57.         val data_all_train = right_text.++(negative_text) //训练集RDD合并
  58.       
  59.      /***********************end 分词******************************************/
  60.  
  61.         
  62.         
  63.     /***********************start 向量化******************************************/
  64.         
  65.        val hashingTF = new HashingTF(Math.pow(2, 18).toInt)
  66.         
  67.         //训练集TF向量化
  68.         val documents_train = data_all_train.map{
  69.           case(num,str) =>
  70.             (num,str.split(" ").toSeq)
  71.         }
  72.        
  73.         val tf_num_pairs_train = documents_train.map {
  74.         case (num,seq) =>
  75.           val tf = hashingTF.transform(seq)
  76.           (num,tf)
  77.         }
  78.         
  79.         //待分类TF向量化
  80.         val documents_wait = wait_text.map{
  81.           case(num,str) =>
  82.             (num,str.split(" ").toSeq)
  83.         }
  84.            
  85.         val tf_num_pairs_wait = documents_wait.map {
  86.           case (num,seq) =>
  87.             val tf = hashingTF.transform(seq)
  88.             (num,tf)
  89.         }
  90.       
  91.         tf_num_pairs_train.cache()
  92.         tf_num_pairs_wait.cache()
  93.      
  94.       
  95.        //利用训练集TF构建IDF MODEL
  96.        val idf = new IDF().fit(tf_num_pairs_train.values)
  97.       
  98.      
  99.       //将训练集tf向量转换成tf-idf向量
  100.       val num_idf_pairs_train = tf_num_pairs_train.mapValues(=> idf.transform(v)) 
  101.       //将待分类数据集tf向量转换成tf-idf向量
  102.       val num_idf_pairs_wait = tf_num_pairs_wait.mapValues(=> idf.transform(v)) 
  103.       
  104.       //格式转换 
  105.       val trainCollection = num_idf_pairs_train.map{
  106.         case(num,vector) => 
  107.            val vectorStr = num +" "+VectorToStr.change(vector)
  108.            vectorStr
  109.       } 
  110.         
  111.       val waitCollection = num_idf_pairs_wait.map{
  112.         case(num,vector) => 
  113.            val vectorStr = num +" "+VectorToStr.change(vector)
  114.            vectorStr
  115.       } 
  116.       
  117.       //落地 (后期可参看MLUtils源码来直接转换为LabeledPoint避免落地)
  118.       trainCollection.coalesce(1).saveAsTextFile(train_vectors_local)
  119.       waitCollection.coalesce(1).saveAsTextFile(wait_vectors_local)
  120.       
  121.     /***********************end 向量化******************************************/
  122.         
  123.     /***********************start SVM模型训练******************************************/
  124.     
  125.         val vectors_train = MLUtils.loadLibSVMFile(sc,train_vectors_local).cache()
  126.         val vectors_wait = MLUtils.loadLibSVMFile(sc,wait_vectors_local).cache()
  127.      
  128.        
  129.         //1 新建SVM模型,并设置训练参数 
  130.         
  131.         val numIterations = iterative_number    //迭代次数,并非越大越好,需根据训练集不断调整来确定该值
  132.         
  133.         val stepSize = 1 
  134.         
  135.         val miniBatchFraction = 1.0             //步长
  136.         
  137.         val model = SVMWithSGD.train(vectors_train, numIterations, stepSize, miniBatchFraction) 
   
    
  1.      /***********************start SVM模型训练******************************************/
  2.         
  3.         
  4.      /***********************start 分类******************************************/
  5.         
  6.         //4 对待分类数据向量进行分类 
  7.       
  8.         println("---------------训练完成------------------------")
  9.         
  10.         val prediction_wait = model.predict(vectors_wait.map(_.features))
  11.         
  12.         println("---------------分类完成------------------------")
  13.       
  14.         prediction_wait.saveAsTextFile("/user/wzx/cs1")
  15.       
  16.         println("---------------保存完成------------------------")
  17.      
  18.      /***********************end 分类******************************************/ 
  19.     
  20.       
  21.    }
  22. }


  1. /**
  2.  * IK分词 去掉停用词处理
  3.  * @author wangzengxu
  4.  *
  5.  */
  6. public class IKUtils {
  7.     
  8.      
  9.     
  10.      public static String participle(String text){
  11.      StringBuffer result = new StringBuffer();
  12.      //读入停用词文件
  13.      BufferedReader StopWordFileBr = new BufferedReader(new InputStreamReader(IKUtils.class.getResourceAsStream("/stopword.dic")));    //注意jar包路径问题
  14.      //用来存放停用词的集合
  15.      Set<String> stopWordSet = new HashSet<String>();
  16.      //初如化停用词集
  17.      String stopWord = null;
  18.      try {
  19.             for(; (stopWord = StopWordFileBr.readLine()) != null;){
  20.              stopWordSet.add(stopWord);
  21.              }
  22.         } catch (IOException e) {
  23.             e.printStackTrace();
  24.         }
  25.      //创建分词对象
  26.      StringReader sr=new StringReader(text);
  27.      IKSegmenter ik=new IKSegmenter(sr, false);
  28.      Lexeme lex=null;
  29.      //分词
  30.      try {
  31.             while((lex=ik.next())!=null){
  32.              //去除停用词
  33.              if(stopWordSet.contains(lex.getLexemeText())) {
  34.              continue;
  35.              }
  36.              result.append(lex.getLexemeText()+" ");
  37.              }
  38.         } catch (IOException e) {
  39.             e.printStackTrace();
  40.         }
  41.      //关闭流
  42.      try {
  43.             StopWordFileBr.close();
  44.         } catch (IOException e) {
  45.             e.printStackTrace();
  46.         }
  47.      return result.toString();
  48.      }
  49.      
  50. }


dome中提供了评分代码,在模型训练时需要根据评分来不断调整迭代次数等来达到满意的精度。当然,这个dome还有很多优化空间



来自 “ ITPUB博客 ” ,链接:http://blog.itpub.net/29754888/viewspace-1967758/,如需转载,请注明出处,否则将追究法律责任。

转载于:http://blog.itpub.net/29754888/viewspace-1967758/

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值