PLSA 在这里使用 EM 算法来估计参数。
用EM算法来估计参数主要有两步:在PLSA中,E步是根据当前假设求后验概率P(z|w,d);M步是通过最大化似然函数来求p(w|z)、p(z|d),即重新估计假设;然后再用新的假设求后验概率……如此循环,直到达到设定的迭代次数。具体的公式可参考上面的文档。
具体的代码为:
import java.io.File;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import com.bj58.data.dataming.machinelearning.util.FileUtil;
import com.bj58.data.dataming.machinelearning.util.IKParticiple;
import com.bj58.data.dataming.machinelearning.util.Participle;
import com.bj58.data.dataming.machinelearning.util.Probability;
public class Plsa2 {
private static Participle participle = new IKParticiple();
private static Set wordSet = new HashSet();
private static List topicList = new ArrayList();
private static List docList = new ArrayList();
private static Map dtmap = new HashMap();
private static Map twmap = new HashMap();
private static Map> dwSetMap = new HashMap>();
private static Map dtwmap = new HashMap();
private static Map dwcount = new HashMap();
static{
topicList.add("1");
topicList.add("2");
topicList.add("3");
}
public static void main(String[] args){
String trainPath = "E:\\学习\\文本处理\\data\\SogouC.reduced.20061127\\SogouC.reduced\\plsa";
readFile(trainPath);
//初始化假设p(z|d)
for(String doc : docList){
double all = 0d;
for(String topic : topicList){
double random = Math.random();
all += random;
Probability p = new Probability(doc,topic);
dtmap.put(p.toString(), random);
}
for(String topic : topicList){
Probability p = new Probability(doc,topic);
dtmap.put(p.toString(), dtmap.get(p.toString())/all);
}
}
//初始化假设p(w|z)
for(String topic : topicList){
double all = 0d;
for(String word : wordSet){
double random = Math.random();
Probability p = new Probability(topic,word);
twmap.put(p.toString(), random);
all += random;
}
for(String word : wordSet){
Probability p = new Probability(topic,word);
twmap.put(p.toString(), twmap.get(p.toString())/all);
}
}
for(int i=0;i<100;i++){
E();
M();
}
//排序
List> orderList=new ArrayList>(twmap.entrySet());
Collections.sort(orderList, new Comparator>() {
public int compare(Map.Entry o1, Map.Entry o2) {
return o2.getValue() > o1.getValue()? 1:-1;
}
});
Map> ml = new HashMap>();
Set set = new HashSet();
for(Map.Entry me : orderList){
String [] words =me.getKey().split(",");
if(set.contains(words[1]))
continue;
set.add(words[1]);
if(ml.get(words[0]) == null){
List list = new ArrayList();
ml.put(words[0], list);
}
ml.get(words[0]).add(words[1]);
}
for(Map.Entry> me : ml.entrySet()){
System.out.println(me.getKey()+"="+me.getValue().toString());
}
}
private static void E(){
for(Map.Entry> me : dwSetMap.entrySet()){
String doc = me.getKey();
for( String word : me.getValue()){
double fenmu = 0d;
for(String topic : topicList){
Probability pdt = new Probability(doc,topic);
Probability ptw = new Probability(topic,word);
double fenzi = dtmap.get(pdt.toString())*twmap.get(ptw.toString());
fenmu += fenzi;
Probability pdtw = new Probability(doc,topic,word);
dtwmap.put(pdtw.toString(), fenzi);
}
for(String topic : topicList){
Probability pdtw = new Probability(doc,topic,word);
dtwmap.put(pdtw.toString(), dtwmap.get(pdtw.toString())/fenmu);
}
}
}
}
private static void M(){
//重新估计p(z|d)
for( String doc : docList){
int fenmu = 0;
for(String word : dwSetMap.get(doc)){
Probability pdw = new Probability(doc,word);
fenmu += dwcount.get(pdw.toString());
}
for(String topic : topicList){
double fenzi = 0d;
for(String word : dwSetMap.get(doc)){
Probability pdw = new Probability(doc,word);
Probability pdtw = new Probability(doc,topic,word);
fenzi += dwcount.get(pdw.toString())*dtwmap.get(pdtw.toString());
}
Probability pdt = new Probability(doc,topic);
dtmap.put(pdt.toString(), fenzi/fenmu);
}
}
//重新估计p(w|z)
for(String topic : topicList){
double fenmu = 0d;
for(String doc : docList){
for(String word : dwSetMap.get(doc)){
Probability pdw = new Probability(doc,word);
Probability pdtw = new Probability(doc,topic,word);
fenmu += dwcount.get(pdw.toString())*dtwmap.get(pdtw.toString());
}
}
for(String word : wordSet){
double fenzi = 0d;
for(String doc : docList){
Probability pdw = new Probability(doc,word);
Probability pdtw = new Probability(doc,topic,word);
if(dwcount.get(pdw.toString())==null)
continue;
if( dtwmap.get(pdtw.toString())==null)
continue;
fenzi += dwcount.get(pdw.toString())*dtwmap.get(pdtw.toString());
}
Probability ptw = new Probability(topic,word);
twmap.put(ptw.toString(), fenzi/fenmu);
}
}
}
private static void readFile(String path){
File ppfile = new File(path);
if( ppfile.isDirectory()){
File [] pfiles = ppfile.listFiles();
for( int i=0;i
String categoryName = pfiles[i].getName();
File[] files = pfiles[i].listFiles();
for( File file : files){
String docName = categoryName+","+file.getName();
docList.add(docName);
List fileContent = FileUtil.readFile(file.getPath(),"GB2312");
List participleResult = participle.getParticipleResultList(fileContent.toString());
Set set = new HashSet();
set.addAll(participleResult);
dwSetMap.put(docName, set);
wordSet.addAll(participleResult);
for(String word : participleResult){
Probability p = new Probability(docName,word);
if( dwcount.get(p.toString()) == null){
dwcount.put(p.toString(), 1);
}else{
dwcount.put(p.toString(), dwcount.get(p.toString())+1);
}
}
}
}
}
}
}
其中一个实体类为:
/**
 * Builds the comma-separated string keys used to index probability tables.
 * Holds two or three string components; the third is optional.
 */
public class Probability {
    private final String str1;
    private final String str2;
    private final String str3;

    /**
     * Creates a three-part key.
     *
     * @param doc   first component
     * @param topic second component
     * @param word  third component
     */
    public Probability(String doc, String topic, String word) {
        this.str1 = doc;
        this.str2 = topic;
        this.str3 = word;
    }

    /**
     * Creates a two-part key; the third component is left null.
     *
     * @param str1 first component
     * @param str2 second component
     */
    public Probability(String str1, String str2) {
        this(str1, str2, null);
    }

    /**
     * Returns {@code "str1,str2,str3"} when all three components are non-null,
     * otherwise {@code "str1,str2"} (a null component is rendered as "null").
     */
    @Override
    public String toString() {
        String firstTwo = str1 + "," + str2;
        boolean allPresent = str1 != null && str2 != null && str3 != null;
        return allPresent ? firstTwo + "," + str3 : firstTwo;
    }
}