使用spark mllib 随机森林算法对文本进行多分类

1、数据准备

20万条人工标注的文本数据，每行格式为“标签#k-v#文本”，样本如下：

1#k-v#*亮亮爱宠*波波宠物指甲钳指甲剪附送锉刀适用小型犬及猫特价
1#k-v#*顺丰包邮*宠物药品圣马利诺PowerIgG免疫力球蛋白犬猫细小病毒
1#k-v#*包邮*法国罗斯蔓草本精华宠物浴液薰衣草护色润泽香波拍套餐
1#k-v#*包邮*家朵102宠物沐浴液
1#k-v#*包邮*家朵102宠物沐浴液猫

2、分词

使用ansj分词包对文本进行分词，并去除停用词和单字词，代码如下：

import java.io.File;
import java.io.IOException;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import org.ansj.domain.Result;
import org.ansj.domain.Term;
import org.ansj.splitWord.analysis.ToAnalysis;
import org.apache.commons.io.FileUtils;
import org.apache.commons.lang3.StringUtils;

public class Seg{

    private static Set<String> stopwords = new HashSet<String>();
    static{
        File f = new File("");
        try {
            List<String> lines = FileUtils.readLines(f);
            for(String str : lines){
                stopwords.add(str);
            }
        } catch (IOException e) {
            e.printStackTrace();
        }


    }

    public static void main(String[] args) throws IOException {
        File f = new File("");
        File resultFile = new File("");

        List<String> lists = FileUtils.readLines(f);

        int count = 0;
        for(String str : lists){
            count++;
            String index = str.split("#k-v#")[0];
//          System.out.println(count + " " + Integer.parseInt(index));

            Result res = ToAnalysis.parse(str.split("#k-v#")[1]);

            List<Term> terms  = res.getTerms();

            String wordStr = "";
            for(Term t : terms){
                String word = t.getName();
                if(word.length()>1&&!stopwords.contains(word)){
                    wordStr = wordStr + " " +  word;
                }
            }
            if(!StringUtils.isEmpty(wordStr)){
                FileUtils.write(resultFile, index + "#k-v#" + wordStr + "\n" , true);
            }

            System.out.println(count);
        }

    }

3、对分词数据进行tfidf转换

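这里使用Spark MLlib自带的TF-IDF工具（HashingTF + IDF），代码如下。注意示例中setNumFeatures(26)只设置了26个哈希桶，真实文本中大量不同的词会被哈希到同一维度；实际使用时应设置得大得多（Spark的默认值为2^18）。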

import org.apache.spark.ml.feature.{HashingTF, IDF, Tokenizer}

import org.apache.spark.sql.SQLContext
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.Row

//case class FileRecord(index:Int,seg: String)

/**
 * Converts the segmented corpus into TF-IDF feature vectors and saves them as JSON.
 *
 * Reads `/tmp/seg_src.txt`, where each line is `label#k-v# word1 word2 ...` (the output of
 * the segmentation step). Writes one `{"features": ..., "label": ...}` JSON record per
 * document under `/tmp/tfidf.model`. Despite the name, that directory holds feature data,
 * not a saved model.
 *
 * Usage: `TfIdf [numFeatures]`. `numFeatures` is the number of HashingTF buckets. It
 * defaults to 26, the value the rest of this pipeline was written for. With so few
 * buckets almost every word collides with others. Real text normally needs a far larger
 * value, and the classifier's `setMaxCategories` must then be revisited as well.
 */
object TfIdf {

  /** Separator between the label and the segmented text on each input line. */
  private val Delimiter = "#k-v#"

  def main(args: Array[String]): Unit = {
    val numFeatures = args.headOption.map(_.toInt).getOrElse(26)
    require(numFeatures > 0, s"numFeatures must be positive, got $numFeatures")

    val conf = new SparkConf().setAppName("TfIdfExample")
    val sc = new SparkContext(conf)
    try {
      val sqlContext = new SQLContext(sc)

      val schema = StructType(
        "index seg".split(" ").map(name => StructField(name, StringType, nullable = true)))

      val srcRDD = sc.textFile("/tmp/seg_src.txt", 1)
        .map(_.split(Delimiter))
        // A line with no text after the delimiter would otherwise abort the whole job with
        // ArrayIndexOutOfBoundsException; skip it instead.
        .filter(_.length >= 2)
        .map(attributes => Row(attributes(0), attributes(1).trim))

      // The "index" field is the class label, so expose it under that name.
      val sentenceData = sqlContext.createDataFrame(srcRDD, schema).toDF("label", "seg")

      // Turn the space-joined words back into a word array.
      val tokenizer = new Tokenizer().setInputCol("seg").setOutputCol("words")
      val wordsData = tokenizer.transform(sentenceData)

      val hashingTF = new HashingTF()
        .setInputCol("words")
        .setOutputCol("rawFeatures")
        .setNumFeatures(numFeatures)
      val featurizedData = hashingTF.transform(wordsData)

      val idfModel = new IDF().setInputCol("rawFeatures").setOutputCol("features").fit(featurizedData)
      val rescaledData = idfModel.transform(featurizedData)
      rescaledData.select("features", "label").take(3).foreach(println)

      rescaledData.select("features", "label").write.format("json").save("/tmp/tfidf.model")
    } finally {
      sc.stop()
    }
  }
}

得到的是json数据格式,示例数据如下:

{"features":{"type":0,"size":26,"indices":[0,5,6,7,9,10,14,17,21],"values":[2.028990788466258,1.8600672974067514,1.8464729103095205,2.037399707294254,1.908861495143531,3.6260607728633083,2.0363086347259687,1.8261747092361593,2.0640809711702492]},"label":"1"}
{"features":{"type":0,"size":26,"indices":[7,8,17],"values":[4.074799414588508,2.1216332358971366,1.8261747092361593]},"label":"1"}

4、json数据转libsvm数据格式

因为下面的随机森林分类代码通过sqlContext.read.format("libsvm")读取libsvm格式的数据，所以先把json数据转换为libsvm格式，代码如下：

    // Converts the TF-IDF JSON records into libsvm lines of the form "label i1:v1 i2:v2 ...".
    //
    // libsvm feature indices are one-based. Spark's libsvm reader subtracts 1 from every
    // index and then requires the results to be strictly ascending from -1, so a 0 in the
    // file is rejected. HashingTF indices are zero-based, so every index is shifted by +1
    // here. Simply dropping feature 0 would lose that feature and move every other feature
    // down one slot when the file is read back.
    //
    // NOTE(review): Spark's JSON writer produces a directory of part files; this assumes they
    // were concatenated into the single file below.
    File f = new File("D:/sogouOutput/json_feature");
    File libsvmFile = new File("D:/sogouOutput/libsvm_feature");

    List<String> features = FileUtils.readLines(f, "UTF-8");
    StringBuilder out = new StringBuilder();

    for (String str : features) {
        JSONObject obj = new JSONObject(str);
        String label = obj.getString("label");

        JSONObject vector = obj.getJSONObject("features");
        JSONArray valueArr = vector.getJSONArray("values");
        // Sparse records carry "indices"; a record without them is treated as dense (index = position).
        JSONArray indexArr = vector.optJSONArray("indices");

        // TreeMap keeps the indices in the ascending order libsvm requires.
        Map<Integer, Double> indiceAndValue = new TreeMap<Integer, Double>();
        for (int i = 0; i < valueArr.length(); i++) {
            int zeroBased = indexArr != null ? indexArr.getInt(i) : i;
            indiceAndValue.put(zeroBased + 1, valueArr.getDouble(i));
        }

        StringBuilder line = new StringBuilder(label);
        for (Map.Entry<Integer, Double> m : indiceAndValue.entrySet()) {
            line.append(' ').append(m.getKey()).append(':').append(m.getValue());
        }
        out.append(line).append('\n');
    }

    // One append-mode write instead of reopening the output file for every record.
    FileUtils.write(libsvmFile, out, "UTF-8", true);

结果示例数据如下:

1 7:2.037399707294254 
1 1:1.6033119355738932 5:1.8600672974067514 7:4.074799414588508 10:1.8130303864316542 13:2.0344821501999344 15:2.2043195316439834 18:2.0104112775954426 20:2.0108489143639154 25:1.9189925465072746 
1 3:5.510668692397079 5:1.8600672974067514 6:1.8464729103095205 7:4.074799414588508 17:1.8261747092361593 
1 3:5.510668692397079 5:1.8600672974067514 6:1.8464729103095205 7:2.037399707294254 13:2.0344821501999344 17:1.8261747092361593 20:2.0108489143639154 
1 7:2.037399707294254 

5、分类

分类代码如下:

import org.apache.spark.sql.SQLContext
import org.apache.spark.{SparkConf, SparkContext}
// $example on$
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}
// $example off$

/**
 * Trains a random-forest text classifier on the libsvm features and reports its test error.
 *
 * Reads `/tmp/libsvm_feature` and splits it 70/30 into training and test sets. It then fits
 * a pipeline of label indexing, feature indexing, a 10-tree random forest, and conversion of
 * predictions back to the original labels. Finally it prints sample predictions, the test
 * error and the learned forest.
 */
object RandomForestClassifierExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("RandomForestClassifierExample")
    val sc = new SparkContext(conf)
    try {
      val sqlContext = new SQLContext(sc)

      // Load the libsvm file as a DataFrame with "label" and "features" columns.
      // Every line's feature indices must be one-based and strictly ascending. Otherwise Spark
      // fails with "indices should be one-based and in ascending order". The load is lazy, so
      // the error surfaces at the first action (the StringIndexer fit below).
      val data = sqlContext.read.format("libsvm").load("/tmp/libsvm_feature")

      // Index labels, adding metadata to the label column.
      // Fit on the whole dataset so every label gets an index.
      val labelIndexer = new StringIndexer()
        .setInputCol("label")
        .setOutputCol("indexedLabel")
        .fit(data)

      // Features with at most 26 distinct values are treated as categorical; the rest are continuous.
      // NOTE(review): Spark's tree learners require maxBins (default 32) to be at least the largest
      // category count, so keep this value at or below maxBins.
      val featureIndexer = new VectorIndexer()
        .setInputCol("features")
        .setOutputCol("indexedFeatures")
        .setMaxCategories(26)
        .fit(data)

      // 70% training / 30% test. The fixed seed makes the split, and therefore the reported
      // error, reproducible between runs.
      val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3), 1234L)

      val rf = new RandomForestClassifier()
        .setLabelCol("indexedLabel")
        .setFeaturesCol("indexedFeatures")
        .setNumTrees(10)

      // Convert indexed predictions back to the original label values.
      val labelConverter = new IndexToString()
        .setInputCol("prediction")
        .setOutputCol("predictedLabel")
        .setLabels(labelIndexer.labels)

      // Chain indexers and forest in a Pipeline; fitting it also runs the indexers.
      val pipeline = new Pipeline().setStages(Array(labelIndexer, featureIndexer, rf, labelConverter))
      val model = pipeline.fit(trainingData)

      val predictions = model.transform(testData)
      predictions.select("predictedLabel", "label", "features").show(5)

      // NOTE(review): "precision" is the Spark 1.x metric name (overall precision, equal to accuracy
      // for single-label multiclass data); Spark 2.x replaced it with "accuracy".
      val evaluator = new MulticlassClassificationEvaluator()
        .setLabelCol("indexedLabel")
        .setPredictionCol("prediction")
        .setMetricName("precision")
      val accuracy = evaluator.evaluate(predictions)
      println("Test Error = " + (1.0 - accuracy))

      // Find the forest by type rather than relying on its position in the pipeline.
      val rfModel = model.stages
        .collectFirst { case m: RandomForestClassificationModel => m }
        .getOrElse(sys.error("pipeline contains no RandomForestClassificationModel stage"))
      println("Learned classification forest model:\n" + rfModel.toDebugString)
    } finally {
      sc.stop()
    }
  }
}

在运行过程中，val labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(data)
这句代码会报错:

Caused by: java.lang.IllegalArgumentException: requirement failed: indices should be one-based and in ascending order

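经查找，原因是libsvm格式的特征索引必须从1开始（one-based）且严格升序。从下面的源码可以看到，Spark读取时会把每个索引减1，转换为从0开始的下标，并要求结果严格大于前一个下标（初始值为-1），因此文件中的索引0会被拒绝。而HashingTF输出的索引是从0开始的，所以转换为libsvm时应把每个索引都加1，而不是直接删除索引为0的特征。删除的做法会丢失该特征，并使其余特征在重新读入后整体错位一位。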

/**
 * Parses one libsvm text line of the form `label index1:value1 index2:value2 ...`.
 * (Excerpt of Spark's internal libsvm parser, quoted unchanged apart from comments.)
 *
 * Each index in the text is treated as one-based and converted to zero-based by
 * subtracting 1. The converted indices must then be strictly ascending, starting above -1.
 * An index of 0 in the file therefore becomes -1 and fails the `require` check.
 *
 * @param line one libsvm line whose fields are separated by single spaces
 * @return (label, zero-based feature indices, feature values)
 * @throws IllegalArgumentException if the indices are not one-based and strictly ascending
 * @throws NumberFormatException if the label, an index or a value is not numeric
 */
private[spark] def parseLibSVMRecord(line: String): (Double, Array[Int], Array[Double]) = {
    val items = line.split(' ')
    // The first field is the label; it must parse as a Double.
    val label = items.head.toDouble
    // Empty fields (e.g. from repeated or trailing spaces) are ignored.
    val (indices, values) = items.tail.filter(_.nonEmpty).map { item =>
      val indexAndValue = item.split(':')
      val index = indexAndValue(0).toInt - 1 // Convert 1-based indices to 0-based.
      val value = indexAndValue(1).toDouble
      (index, value)
    }.unzip

    // check if indices are one-based and in ascending order
    var previous = -1
    var i = 0
    val indicesLength = indices.length
    while (i < indicesLength) {
      val current = indices(i)
      // `current` is already zero-based: a file index of 0 arrives here as -1 and fails,
      // as does any index that is not larger than the one before it.
      require(current > previous, s"indices should be one-based and in ascending order;"
        + s""" found current=$current, previous=$previous; line="$line"""")
      previous = current
      i += 1
    }
    (label, indices.toArray, values.toArray)
  }
  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值