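Spark MLlib can convert documents into TF-IDF-weighted vectors, but the TF step maps tokens to vector indices by hashing. Also, if the resulting vectors are saved as-is, the output format (a SparseVector prints as (size,[indices],[values])) cannot be read directly by algorithms such as SVM and Naive Bayes. Some extra work is therefore needed:
1. Adjusting the vector format
The code that computes TF stays as usual: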
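import java.util.Arrays;
import java.util.List;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.feature.HashingTF;
import org.apache.spark.mllib.linalg.Vector;
import scala.Tuple2;
JavaRDD<String> text = sc.textFile(inputPath); // each line: "<key>\t<space-separated tokens>"
JavaPairRDD<String, List<String>> document = text.mapToPair(new PairFunction<String, String, List<String>>() {
    public Tuple2<String, List<String>> call(String s) {
        String[] str = s.split("\t");
        // Tokenize only the text after the tab; splitting the whole line would glue the key onto the first token
        return new Tuple2<String, List<String>>(str[0], Arrays.asList(str[1].split(" ")));
    }
});
HashingTF tf = new HashingTF(); // 2^20 hash buckets by default
JavaRDD<List<String>> features = document.values();
JavaRDD<Vector> termFreqs = tf.transform(features);
termFreqs.cache(); // reused by both idf.fit and transform below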
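Because TF relies on hashing, a token's vector index is a hash bucket rather than a position in a vocabulary. The sketch below is a plain-Java imitation, not Spark code: it assumes the Spark 1.x behaviour of MLlib's HashingTF (2^20 features by default, index = non-negative remainder of String.hashCode() by the feature count). HashingTfSketch and its methods are hypothetical names.

```java
import java.util.Arrays;
import java.util.List;
import java.util.TreeMap;

// Hypothetical re-implementation for illustration only; assumed to mirror
// Spark 1.x mllib HashingTF: index = nonNegativeMod(term.hashCode(), numFeatures).
class HashingTfSketch {
    static final int DEFAULT_NUM_FEATURES = 1 << 20; // 1,048,576 buckets

    // Java's % can return a negative value for negative hash codes; shift it into [0, mod)
    static int nonNegativeMod(int x, int mod) {
        int raw = x % mod;
        return raw < 0 ? raw + mod : raw;
    }

    static int indexOf(String term, int numFeatures) {
        return nonNegativeMod(term.hashCode(), numFeatures);
    }

    // Sparse TF vector as a sorted map: bucket index -> raw count.
    // Distinct terms that land in the same bucket are counted together (a hash collision).
    static TreeMap<Integer, Double> termFrequencies(List<String> doc, int numFeatures) {
        TreeMap<Integer, Double> tf = new TreeMap<>();
        for (String term : doc) {
            tf.merge(indexOf(term, numFeatures), 1.0, Double::sum);
        }
        return tf;
    }

    public static void main(String[] args) {
        List<String> doc = Arrays.asList("spark", "svm", "spark");
        System.out.println(termFrequencies(doc, DEFAULT_NUM_FEATURES));
    }
}
```

This is why the indices in the sample output are large and scattered, and why two different words can end up sharing one dimension.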
Next comes the IDF step. Note that the anonymous function changes how each vector is represented:
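import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.feature.IDF;
import scala.runtime.AbstractFunction2;
import scala.runtime.BoxedUnit;

IDF idf = new IDF();
JavaRDD<Vector> tfIdfs = idf.fit(termFreqs).transform(termFreqs);
// Turn each vector into a sparse "index:value index:value ..." string
JavaRDD<String> tfidfVector = tfIdfs.map(new Function<Vector, String>() {
    public String call(Vector v) {
        final StringBuilder builder = new StringBuilder();
        // foreachActive expects a Scala (index, value) => Unit function; from Java, subclass AbstractFunction2
        AbstractFunction2<Object, Object, BoxedUnit> f = new AbstractFunction2<Object, Object, BoxedUnit>() {
            public BoxedUnit apply(Object t1, Object t2) {
                int dim = (Integer) t1;
                if (dim > 1) { // skip dimensions 0 and 1, see the note below
                    builder.append(dim + ":" + (Double) t2 + " ");
                }
                return BoxedUnit.UNIT;
            }
        };
        v.foreachActive(f);
        // Drop the trailing space; guard against documents whose every term was filtered out
        if (builder.length() > 0) {
            builder.deleteCharAt(builder.length() - 1);
        }
        return builder.toString();
    }
});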
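After this change each vector is written in LibSVM-style sparse index:value form. Once a class label is prepended to each line, the output can be used directly with MLlib algorithms such as SVM and Naive Bayes. A sample of the output: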
22909:1.0986122886681098 119158:1.0986122886681098 639018:1.0986122886681098 735243:1.0986122886681098
20154:1.0986122886681098 24456:0.4054651081081644 37117:0.6931471805599453 201116:0.6931471805599453 875579:1.0986122886681098
113009:1.0986122886681098 127612:1.0986122886681098 686294:1.0986122886681098 736858:1.0986122886681098 832444:1.0986122886681098
20250:0.6931471805599453 21644:1.0986122886681098 24456:0.4054651081081644 25105:1.0986122886681098 37117:0.6931471805599453 119301:1.0986122886681098 201116:0.6931471805599453 730991:1.0986122886681098
20250:0.6931471805599453 24456:0.4054651081081644 26469:1.0986122886681098 30340:2.1972245773362196 35828:1.0986122886681098 38271:1.0986122886681098 689163:1.0986122886681098 704478:1.0986122886681098 750005:1.0986122886681098 779641:1.0986122886681098 796407:1.0986122886681098 798459:1.0986122886681098
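The sample weights can be checked by hand. MLlib's IDF uses idf(t) = ln((m + 1) / (df(t) + 1)), where m is the number of documents and df(t) the number of documents containing term t, and each TF-IDF value is the raw term count times idf(t). Assuming the five lines above are the entire corpus (m = 5), a term in one document gets ln 3 ≈ 1.0986, a term in two documents ln 2 ≈ 0.6931 (e.g. 37117), and a term in three documents ln 1.5 ≈ 0.4055 (24456 appears in three lines); index 30340 (≈ 2.1972 = 2 × ln 3) is then a term occurring twice in a single document. IdfCheck is a hypothetical class used only for this check.

```java
// Hypothetical check of the sample TF-IDF values, assuming m = 5 documents in total.
class IdfCheck {
    // MLlib's smoothed IDF: ln((m + 1) / (df + 1))
    static double idf(long numDocs, long docFreq) {
        return Math.log((numDocs + 1.0) / (docFreq + 1.0));
    }

    // TF-IDF weight = raw term count * idf
    static double tfIdf(double termCount, long numDocs, long docFreq) {
        return termCount * idf(numDocs, docFreq);
    }

    public static void main(String[] args) {
        long m = 5;
        System.out.println(tfIdf(1, m, 1)); // e.g. index 22909, approx. 1.0986
        System.out.println(tfIdf(1, m, 2)); // e.g. index 37117, approx. 0.6931
        System.out.println(tfIdf(1, m, 3)); // e.g. index 24456, approx. 0.4055
        System.out.println(tfIdf(2, m, 1)); // index 30340, approx. 2.1972
    }
}
```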
Note these three lines in the code:
if (dim > 1) {
    builder.append(dim + ":" + (Double) t2 + " ");
}
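They filter out dimensions 0 and 1; without this filter, running SVM or Naive Bayes throws an array index out-of-bounds exception. The likely cause is that MLlib's LibSVM reader (MLUtils.loadLibSVMFile) treats feature indices as 1-based and subtracts 1 from each, so a written index of 0 turns into -1.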
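The filter discards whatever terms hash to dimensions 0 and 1. A minimal sketch of an alternative, assuming the file is later read with MLUtils.loadLibSVMFile (which, as we understand it, parses each index and subtracts 1): shift every HashingTF index up by one when writing, so the loader maps it back to the original 0-based dimension and nothing is dropped. LibSvmLineSketch and its methods are hypothetical helpers, not Spark API: parseZeroBasedIndices only imitates the loader's index handling, and toLibSvmLine also prepends the class label to produce a complete "label index:value ..." line.

```java
import java.util.Arrays;

// Hypothetical helpers (not part of Spark) illustrating 1-based LibSVM indices.
class LibSvmLineSketch {
    // Imitates the loader: every written index i becomes i - 1 (1-based -> 0-based),
    // so a written 0 becomes -1, an invalid array position.
    static int[] parseZeroBasedIndices(String features) {
        String trimmed = features.trim();
        if (trimmed.isEmpty()) {
            return new int[0];
        }
        String[] items = trimmed.split("\\s+");
        int[] indices = new int[items.length];
        for (int i = 0; i < items.length; i++) {
            indices[i] = Integer.parseInt(items[i].split(":")[0]) - 1;
        }
        return indices;
    }

    // Writes "label i+1:v ..." from 0-based HashingTF indices, keeping every dimension.
    static String toLibSvmLine(double label, int[] indices, double[] values) {
        if (indices.length != values.length) {
            throw new IllegalArgumentException("indices and values differ in length");
        }
        StringBuilder sb = new StringBuilder().append(label);
        for (int i = 0; i < indices.length; i++) {
            sb.append(' ').append(indices[i] + 1).append(':').append(values[i]);
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        String line = toLibSvmLine(1.0, new int[] {0, 22909}, new double[] {0.25, 0.5});
        System.out.println(line); // 1.0 1:0.25 22910:0.5
        String features = line.substring(line.indexOf(' ') + 1); // strip the label
        System.out.println(Arrays.toString(parseZeroBasedIndices(features))); // [0, 22909]
        System.out.println(Arrays.toString(parseZeroBasedIndices("0:0.25"))); // [-1]
    }
}
```

In the Spark job this would amount to appending (dim + 1) + ":" + value inside the foreachActive callback instead of checking dim > 1.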