算法来自《集体智慧编程》第六章。
原书代码用 Python 实现;这两天在读这一章,于是改用 Java 重新实现了一遍。
package ch6DocumentFiltering;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Set;
/**
 * Counting model shared by the document-filtering classifiers.
 *
 * <p>Keeps, for every feature, how many trained items of each category
 * contained it, and for every category how many items were trained into it.
 * Categories are addressed by name in the public API and by a dense integer
 * index (assigned in order of first training) inside the per-feature arrays.
 */
public class Classifier {
    /** Feature -> occurrence counts, indexed by the category index stored in {@code catMap}. */
    private final HashMap<String, int[]> fc = new HashMap<String, int[]>();
    /** Category name -> number of items trained into that category. */
    private final HashMap<String, Integer> cc = new HashMap<String, Integer>();
    /** Category name -> index used to address the arrays in {@code fc}. */
    private final HashMap<String, Integer> catMap = new HashMap<String, Integer>();
    /** Category name -> threshold applied by {@link #classify(String, String)}. */
    private final HashMap<String, Double> threshold = new HashMap<String, Double>();

    public Classifier() {
    }

    /**
     * Returns the threshold configured for a category.
     *
     * @param key category name
     * @return the configured threshold, or 1.0 when none was set
     */
    public double getThreshold(String key) {
        Double configured = this.threshold.get(key);
        return configured == null ? 1.0 : configured.doubleValue();
    }

    /**
     * Sets the threshold for a category.
     *
     * @param key category name
     * @param t   threshold value
     */
    public void setThreshold(String key, double t) {
        this.threshold.put(key, t);
    }

    /**
     * Extracts the features of a document by delegating to {@code DocClass.getWords}.
     *
     * @param content document text
     * @return the features reported by {@code DocClass}
     */
    public String[] getFeatures(String content) {
        DocClass doc = new DocClass();
        return doc.getWords(content);
    }

    /**
     * Increments the count of a feature for one category index.
     *
     * <p>The per-feature array is grown on demand; slots for categories the
     * feature has never been seen in stay at zero.
     *
     * @param key feature
     * @param cat category index (non-negative)
     * @throws IllegalArgumentException if {@code cat} is negative
     */
    public void infc(String key, int cat) {
        if (cat < 0) {
            throw new IllegalArgumentException("category index must be non-negative: " + cat);
        }
        int[] counts = this.fc.get(key);
        if (counts == null) {
            counts = new int[cat + 1];
            this.fc.put(key, counts);
        } else if (counts.length <= cat) {
            // Grow to cover the new index; only the target slot is incremented below.
            int[] grown = new int[cat + 1];
            System.arraycopy(counts, 0, grown, 0, counts.length);
            counts = grown;
            this.fc.put(key, counts);
        }
        counts[cat]++;
    }

    /**
     * Increments the item count of a category.
     *
     * @param key category name
     */
    public void incc(String key) {
        Integer current = this.cc.get(key);
        this.cc.put(key, current == null ? 1 : current.intValue() + 1);
    }

    /**
     * Number of times a feature has appeared in a category.
     *
     * @param key feature
     * @param cat category index
     * @return the recorded count, or 0.0 when the feature or index is unknown
     */
    public double fcount(String key, int cat) {
        int[] counts = this.fc.get(key);
        if (counts == null || cat < 0 || cat >= counts.length) {
            return 0.0;
        }
        return counts[cat];
    }

    /**
     * Number of items trained into a category.
     *
     * @param cat category name
     * @return the item count, or 0 for an unknown category
     */
    public int catCount(String cat) {
        Integer items = this.cc.get(cat);
        return items == null ? 0 : items.intValue();
    }

    /**
     * Total number of trained items across all categories.
     *
     * @return sum of all category item counts
     */
    public int totalCount() {
        int total = 0;
        for (Integer itemsInCategory : this.cc.values()) {
            total += itemsInCategory;
        }
        return total;
    }

    /**
     * Names of all known categories.
     *
     * @return the key view of the internal category-count map (not a copy)
     */
    public Set<String> getCategories() {
        return this.cc.keySet();
    }

    /**
     * Trains the model with one document: every feature of the document is
     * counted once for the category and the category's item count is incremented.
     *
     * @param item document text
     * @param cat  category name; registered on first use
     */
    public void train(String item, String cat) {
        String[] features = getFeatures(item);
        Integer index = this.catMap.get(cat);
        if (index == null) {
            index = Integer.valueOf(this.catMap.size());
            this.catMap.put(cat, index);
        }
        for (String feature : features) {
            infc(feature, index.intValue());
        }
        incc(cat);
    }

    /**
     * Probability that a feature appears in an item of the given category:
     * feature count in the category divided by the category's item count.
     *
     * @param key feature
     * @param cat category name
     * @return the ratio, or 0.0 when the category has no items
     */
    public double fprob(String key, String cat) {
        int items = catCount(cat);
        if (items == 0) {
            return 0.0;
        }
        return featureCount(key, cat) / (double) items;
    }

    /**
     * Blends {@link #fprob(String, String)} with an assumed probability.
     *
     * <p>The result is {@code (weight*ap + totals*fprob) / (weight + totals)},
     * where {@code totals} is the feature's count summed over all categories.
     *
     * @param key    feature
     * @param cat    category name
     * @param weight weight given to the assumed probability
     * @param ap     assumed probability
     * @return the weighted probability
     */
    public double weightedProb(String key, String cat, double weight, double ap) {
        double basicProb = fprob(key, cat);
        int totals = 0;
        for (String category : getCategories()) {
            totals += (int) featureCount(key, category);
        }
        return ((weight * ap) + (totals * basicProb)) / (weight + totals);
    }

    /**
     * Picks the most probable category for a document using {@code Naivebayes.prob}.
     *
     * <p>The winner is only returned when its probability is at least
     * {@code getThreshold(winner)} times that of every other category;
     * otherwise the default is returned.
     *
     * @param item       document text
     * @param defaultCat value returned when no category wins clearly
     * @return the winning category name or {@code defaultCat}
     */
    public String classify(String item, String defaultCat) {
        HashMap<String, Double> probs = new HashMap<String, Double>();
        Naivebayes bayes = new Naivebayes();
        String best = null;
        double max = 0.0;
        for (String cat : getCategories()) {
            double p = bayes.prob(this, item, cat);
            probs.put(cat, p);
            if (p > max) {
                max = p;
                best = cat;
            }
        }
        if (best == null) {
            // No category scored above zero (or nothing has been trained).
            return defaultCat;
        }
        double factor = getThreshold(best);
        for (String cat : probs.keySet()) {
            if (cat.equals(best)) {
                continue;
            }
            if (probs.get(cat) * factor > max) {
                return defaultCat;
            }
        }
        return best;
    }

    /** Looks up the category's index and returns the feature's count there, or 0.0 if the category has no index. */
    private double featureCount(String key, String cat) {
        Integer index = this.catMap.get(cat);
        return index == null ? 0.0 : fcount(key, index.intValue());
    }

    /**
     * Trains the model with a small fixed sample set.
     */
    public void sampleTrain() {
        this.train("the quick brown fox jumps over the lazy dog", "good");
        this.train("make quick monkey in the online casino", "bad");
        this.train("Nobody owns the water.", "good");
        this.train("the quick rabbit jumps fences", "good");
        this.train("buy pharmaceuticals now", "bad");
    }

    public static void main(String[] args) {
        Classifier c1 = new Classifier();
        c1.sampleTrain();
        Fisher fisher = new Fisher();
        System.out.println(fisher.fisherProb(c1, "quick rabbit", "bad"));
    }
}
package ch6DocumentFiltering;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Locale;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
/**
 * Splits a document into its distinct lower-cased words.
 */
public class DocClass {
    /**
     * Matches runs of word characters. The pattern also produces empty
     * matches; those are discarded by the length filter in {@link #getWords}.
     */
    private static final Pattern WORD_RUN = Pattern.compile("\\w*");

    /**
     * Returns the distinct words of a document.
     *
     * <p>A word is a run of {@code \w} characters, lower-cased, whose length
     * is greater than 2 and less than 20. Each word is reported once.
     *
     * @param content document text; {@code null} is treated as empty
     * @return the distinct words, in no particular order (never {@code null})
     */
    public String[] getWords(String content) {
        if (content == null) {
            return new String[0];
        }
        Set<String> words = new HashSet<String>();
        Matcher m = WORD_RUN.matcher(content);
        while (m.find()) {
            // Locale.ROOT keeps the result independent of the JVM's default
            // locale (e.g. Turkish would map 'I' to a dotless 'ı').
            String word = m.group().toLowerCase(Locale.ROOT);
            if (word.length() > 2 && word.length() < 20) {
                words.add(word);
            }
        }
        return words.toArray(new String[0]);
    }

    public static void main(String[] args) {
        String sample = "A wiki ( /ˈwɪki/ WIK-ee) is a website that allows the easy[1] creation and editing of any number of interlinked web pages via a web browser using a simplified markup language or a WYSIWYG text editor.";
        String[] result = new DocClass().getWords(sample);
        for (String word : result) {
            System.out.println(word);
        }
    }
}
package ch6DocumentFiltering;
/**
 * Naive Bayes scoring on top of the counts held by a {@code Classifier}.
 */
public class Naivebayes {
    /**
     * Pr(Document | Category): the product of the weighted probabilities of
     * every feature of the document, using weight 1.0 and assumed
     * probability 0.5.
     *
     * @param c    trained classifier supplying features and probabilities
     * @param item document text
     * @param cat  category name
     * @return the product over all features (1.0 when there are none)
     */
    public double docProb(Classifier c, String item, String cat) {
        String[] words = c.getFeatures(item);
        double product = 1.0;
        for (int idx = 0; idx < words.length; idx++) {
            product = product * c.weightedProb(words[idx], cat, 1.0, 0.5);
        }
        return product;
    }

    /**
     * Pr(Category) * Pr(Document | Category), where Pr(Category) is the
     * category's item count divided by the total item count.
     *
     * @param c    trained classifier
     * @param item document text
     * @param cat  category name
     * @return the unnormalised score of the category for the document
     */
    public double prob(Classifier c, String item, String cat) {
        double prior = c.catCount(cat) / (double) c.totalCount();
        double likelihood = docProb(c, item, cat);
        return prior * likelihood;
    }

    public static void main(String[] args) {
        Classifier trained = new Classifier();
        trained.sampleTrain();
        Naivebayes bayes = new Naivebayes();
        System.out.println(bayes.prob(trained, "quick monkey", "good"));
        System.out.println(bayes.prob(trained, "rabbit", "bad"));
    }
}
package ch6DocumentFiltering;
import java.util.HashMap;
import java.util.Iterator;
/**
 * Fisher-method scoring on top of the counts held by a {@code Classifier}.
 */
public class Fisher {
    /** Category name -> minimum score a document needs to be assigned to it. */
    private final HashMap<String, Double> minimum = new HashMap<String, Double>();

    /**
     * Returns the minimum score configured for a category.
     *
     * @param cat category name
     * @return the configured minimum, or 0.0 when none was set
     */
    public double getMinimum(String cat) {
        Double configured = this.minimum.get(cat);
        return configured == null ? 0.0 : configured.doubleValue();
    }

    /**
     * Sets the minimum score for a category.
     *
     * @param cat category name
     * @param min minimum score
     */
    public void setMinimum(String cat, double min) {
        this.minimum.put(cat, Double.valueOf(min));
    }

    /**
     * Probability that an item containing the feature belongs to the
     * category: the feature's frequency in this category divided by the sum
     * of its frequencies in all categories.
     *
     * @param c   trained classifier
     * @param f   feature
     * @param cat category name
     * @return the ratio, or 0 when the feature never appears in the category
     */
    public double cprob(Classifier c, String f, String cat) {
        // Frequency of the feature in this category.
        double clf = c.fprob(f, cat);
        if (clf == 0) {
            return 0;
        }
        // Sum of the feature's frequencies over all categories.
        double freqSum = 0;
        for (String each : c.getCategories()) {
            freqSum += c.fprob(f, each);
        }
        return clf / freqSum;
    }

    /**
     * Fisher score of a document for a category.
     *
     * <p>Multiplies the weighted {@link #cprob} of every feature (weight 1.0,
     * assumed probability 0.5), takes {@code -2 * ln} of the product and feeds
     * it to {@link #invchi2} with twice the feature count as degrees of
     * freedom.
     *
     * @param c    trained classifier
     * @param item document text
     * @param cat  category name
     * @return the score; 1.0 for a document without features
     */
    public double fisherProb(Classifier c, String item, String cat) {
        double p = 1.0;
        String[] features = c.getFeatures(item);
        for (String f : features) {
            p *= weightedCprob(c, f, cat, 1.0, 0.5);
        }
        double fScore = -2 * Math.log(p);
        return invchi2(fScore, 2 * features.length);
    }

    /**
     * Inverse chi-square: for even {@code df}, evaluates
     * {@code e^-m * sum(m^i / i!, i = 0 .. df/2 - 1)} with {@code m = chi / 2},
     * capped at 1.0.
     *
     * @param chi chi-square statistic
     * @param df  degrees of freedom; {@code df / 2} is truncated to an integer
     * @return the resulting probability
     */
    public double invchi2(double chi, double df) {
        double m = chi / 2.0;
        if (Double.isInfinite(m) && m > 0) {
            // Limit of the series; evaluating it directly would yield 0 * infinity = NaN.
            return 0.0;
        }
        double term = Math.exp(-m);
        double sum = term;
        int terms = (int) (df / 2);
        for (int i = 1; i < terms; i++) {
            term *= m / i;
            sum += term;
        }
        return Math.min(sum, 1.0);
    }

    /**
     * Picks the category with the highest Fisher score among those whose
     * score exceeds the category's configured minimum.
     *
     * @param c          trained classifier
     * @param item       document text
     * @param defaultCat value returned when no category qualifies
     * @return the winning category name or {@code defaultCat}
     */
    public String classify(Classifier c, String item, String defaultCat) {
        String best = defaultCat;
        double max = 0.0;
        for (String candidate : c.getCategories()) {
            double p = this.fisherProb(c, item, candidate);
            if (p > this.getMinimum(candidate) && p > max) {
                best = candidate;
                max = p;
            }
        }
        return best;
    }

    /** Blends cprob with the assumed probability ap, weighted by how often the feature was seen overall. */
    private double weightedCprob(Classifier c, String f, String cat, double weight, double ap) {
        double basicProb = cprob(c, f, cat);
        long totals = 0;
        for (String each : c.getCategories()) {
            // fprob is count / catCount, so multiplying back and rounding
            // recovers the feature's count in that category.
            totals += Math.round(c.fprob(f, each) * c.catCount(each));
        }
        return ((weight * ap) + (totals * basicProb)) / (weight + totals);
    }

    public static void main(String[] args) {
        Classifier c = new Classifier();
        Fisher fisher = new Fisher();
        fisher.setMinimum("bad", 0.8);
        fisher.setMinimum("good", 0.4);
        c.sampleTrain();
        System.out.println(fisher.classify(c, "casino", "none"));
    }
}