Mahout贝叶斯算法源码分析(6)

首先更正前篇博客中的错误的地方,看图即可看出来:


可以看到和前面猜测的值不对应,第一个怎么是93563,而不是93564?这个看源码就可以看出来了,我当时没有想透彻,key.get()应该是从零开始的,所以一共有93563个单词,key.get()获得的最大值应该是93562,然后最后加上++部分代码,就是93563了,这个确实是我当时没想好。

接着前篇blog的内容,本次应该分析到第六个Job了,且看源代码:

if (shouldPrune) {
。。。
}
 if (processIdf) {
          TFIDFConverter.processTfIdf(
                 new Path(outputDir, DictionaryVectorizer.DOCUMENT_VECTOR_OUTPUT_FOLDER),
                 outputDir, conf, docFrequenciesFeatures, minDf, maxDF, norm, logNormalize,
                 sequentialAccessOutput, namedVectors, reduceTasks);
      }
第一个if是否进入呢?由于前面没有设置shouldPrune,所以这个值是false的,不进入后面的代码块。而processIdf是设置过的,所以为true,进入后面的代码块。进入processTfIdf方法后,看到下面两个重要的函数,即最后的两个Job了:

makePartialVectors(input,
                         baseConf,
                         datasetFeatures.getFirst()[0],
                         datasetFeatures.getFirst()[1],
                         minDf,
                         maxDF,
                         dictionaryChunk,
                         partialVectorOutputPath,
                         sequentialAccessOutput,
                         namedVector);
PartialVectorMerger.mergePartialVectors(partialVectorPaths,
                                            outputDir,
                                            baseConf,
                                            normPower,
                                            logNormalize,
                                            datasetFeatures.getFirst()[0].intValue(),
                                            sequentialAccessOutput,
                                            namedVector,
                                            numReducers);
本篇主要分析第一个makePartialVectors,其实前篇blog也说过这个和前面的seq2sparse(3)的作用差不多,这个Job调用的TFIDFPartialVectorReducer类却和前面的有点不同,经过了一些转换;下面来看此类的仿制代码(Mapper基本没什么作用):

package mahout.fansy.test.bayes;

import java.io.IOException;
import java.net.URI;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IOUtils;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.util.ReflectionUtils;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
import org.apache.mahout.math.NamedVector;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.SequentialAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.map.OpenIntLongHashMap;
import org.apache.mahout.vectorizer.TFIDF;

public class TFIDFPartialVectorReducerFollow {

	private static Configuration conf = new Configuration();
	private static OpenIntLongHashMap dictionary = new OpenIntLongHashMap();
	private static String mapOutPath;
	private static String dictionaryPath;
	private static final TFIDF tfidf = new TFIDF();
	private static int minDf = 1;
	private static long maxDf = 99;
	private static long vectorCount = 18846;
	private static long featureCount=93563;
	private static boolean sequentialAccess=false;
	private static boolean namedVector=true;
	/** 
	 * 初始化各种参数
	 */
	static{
		conf.set("mapred.job.tracker", "ubuntu:9001");
		mapOutPath="hdfs://ubuntu:9000/home/mahout/mahout-work-mahout/20news-vectors/tf-vectors/part-r-00000";
		dictionaryPath="hdfs://ubuntu:9000/home/mahout/mahout-work-mahout/20news-vectors/frequency.file-0";
	}
	public static void main(String[] args) throws IOException {
		setup();
		reduce();
	}

	/**
	 * 获得MakePartialVectors的map输出;
	 * @return
	 * @throws IOException 
	 */
	public static Map<String,List<VectorWritable>> getKeyAndValues() throws IOException{
		Map<String,List<VectorWritable>> map=new HashMap<String,List<VectorWritable>>();
		
	    FileSystem fs = FileSystem.get(URI.create(mapOutPath), conf);
	    Path path = new Path(mapOutPath);

	    SequenceFile.Reader reader = null;
	    try {
	      reader = new SequenceFile.Reader(fs, path, conf);
	      Writable key = (Writable)
	        ReflectionUtils.newInstance(reader.getKeyClass(), conf);
	      Writable value = (Writable)
	        ReflectionUtils.newInstance(reader.getValueClass(), conf);
	      while (reader.next(key, value)) {
	        String k=key.toString();
	        VectorWritable v=(VectorWritable)value;
	        v=new VectorWritable(v.get());  // 第一种方式
	        if(map.containsKey(k)){ //如果包含则把其value值取出来加上一个新的vectorWritable到list中
	        	List<VectorWritable> list=map.get(k);
	        	list.add(v);
	        	map.put(k, list);
	        }else{                 // 否则直接new一个新的list,添加该vectorWritable到list中
	        	List<VectorWritable> list=new ArrayList<VectorWritable>();
	        	list.clear();
	        	list.add(v);
	     //   	List<VectorWritable> listCopy=new ArrayList<VectorWritable>();
	     //   	listCopy.addAll(list);  // 第二种方式
	        	map.put(k, list);
	        	
	        }
	      }
	    } finally {
	      IOUtils.closeStream(reader);
	    }
		return map;
	}
	
	/**
	 * 初始化dictionary
	 */
	public static void setup(){
		Path dictionaryFile = new Path(dictionaryPath);
	    // key is feature, value is the document frequency
	    for (Pair<IntWritable,LongWritable> record 
	         : new SequenceFileIterable<IntWritable,LongWritable>(dictionaryFile, true, conf)) {
	      dictionary.put(record.getFirst().get(), record.getSecond().get());
	    }
	}
	/**
	 * 仿制reduce函数
	 * @throws IOException
	 */
	public static void reduce() throws IOException{
		Map<String,List<VectorWritable>> map=getKeyAndValues();
		Set<String> keySet=map.keySet();
		String key=keySet.iterator().next();
		Iterator<VectorWritable> it = map.get(key).iterator();
	    if (!it.hasNext()) {
	      return;
	    }
	    Vector value = it.next().get();
	    Iterator<Vector.Element> it1 = value.iterateNonZero();
	    Vector vector = new RandomAccessSparseVector((int) featureCount, value.getNumNondefaultElements());
	    while (it1.hasNext()) {
	      Vector.Element e = it1.next();
	      if (!dictionary.containsKey(e.index())) {
	        continue;
	      }
	      long df = dictionary.get(e.index());
	      if (maxDf > -1 && (100.0 * df) / vectorCount > maxDf) {
	        continue;
	      }
	      if (df < minDf) {
	        df = minDf;
	      }
	      vector.setQuick(e.index(), tfidf.calculate((int) e.get(), (int) df, (int) featureCount, (int) vectorCount));
	    }
	    if (sequentialAccess) {
	      vector = new SequentialAccessSparseVector(vector);
	    }
	    
	    if (namedVector) {
	      vector = new NamedVector(vector, String.valueOf(key));
	    }
	    
	    VectorWritable vectorWritable = new VectorWritable(vector);
	  //  context.write(key, vectorWritable);
	    System.out.println(key+", "+vectorWritable);
	  }
	
}
首先getKeyAndValues()和setup()可以暂时不管,这两个只是准备数据的阶段;重点看reduce函数,这里写死了key值,所以效果只能在第一次循环中来看。

reduce函数的大概意思是:针对每篇文章中的单词,如果它出现的次数df满足 (100*df)/vectorCount 》0或者maxDf>1或者df<minDf,那么就退出该个单词的设置(等于是不要这个单词了),其中vectorCount 是一共的文件数;否则就设置该个单词,但是它出现的次数要进行转换,具体公式如下:

sqrt(df)*[log(vectorCount/(df+1)) + 1] ,具体参考下面的代码:

public double calculate(int tf, int df, int length, int numDocs) {
    // ignore length
    return sim.tf(tf) * sim.idf(df, numDocs);
  }
上面代码参考  DefaultSimilarity API 即可看到公式。

分享,快乐,成长


转载请注明出处:http://blog.csdn.net/fansy1990 



  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值