上一节我们已经对训练集建立了word-doc矩阵,每读取矩阵的一行就可以计算出term对应的IG值。最后把结果写入文件。
信息增益的计算公式是什么我就不介绍了。
代码如下:
View Code
import java.io.BufferedReader;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.PrintWriter;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Date;
/**
 * Computes the information gain (IG) of every term in a word-doc matrix file
 * and writes one "term&lt;TAB&gt;IG" line per term to an output file.
 */
public class CalIG {
    /**
     * Reads the word-doc matrix one row at a time and writes the IG of each term.
     *
     * Each row is expected to hold the term followed by one whitespace-separated
     * count per document. Document columns are consumed category by category, in
     * the order given by the per-category document counts below.
     *
     * @param matrixFile word-doc matrix to read; the JVM exits with status 2 if it does not exist
     * @param IGFile     output file receiving "term\tIG" lines
     */
    public void calIG(File matrixFile, File IGFile) {
        if (!matrixFile.exists()) {
            System.out.println("Matrix文件不存在.程序退出.");
            System.exit(2);
        }
        // Number of documents in each of the 9 categories.
        int[] category_count = {1070, 440, 513, 816, 750, 756, 1392, 473, 986};
        // Derived from the array so the three values can never disagree.
        int category_num = category_count.length;
        int doc_num = 0; // total number of documents == number of matrix columns (7196)
        for (int i = 0; i < category_num; i++) {
            doc_num += category_count[i];
        }
        double HC = getEntropy(category_count);

        BufferedReader br = null;
        PrintWriter pw = null;
        try {
            br = new BufferedReader(new FileReader(matrixFile));
            pw = new PrintWriter(new FileOutputStream(IGFile));
            String line = null;
            while ((line = br.readLine()) != null) {
                if (line.trim().length() == 0) {
                    continue; // tolerate blank lines (e.g. at the end of the file)
                }
                String[] content = line.split("\\s+");
                String term = content[0];
                if (content.length < doc_num + 1) {
                    throw new IllegalArgumentException("Row for term '" + term + "' has "
                            + (content.length - 1) + " columns, expected " + doc_num);
                }

                int term_count = 0;                                // documents containing the term
                int[] term_class_count = new int[category_num];    // per category: documents containing the term
                int[] term_b_class_count = new int[category_num];  // per category: documents NOT containing the term
                int column = 1;                                    // column 0 is the term itself
                for (int i = 0; i < category_num; i++) {
                    for (int j = 0; j < category_count[i]; j++) {
                        if (Integer.parseInt(content[column]) > 0) {
                            term_class_count[i]++;
                        }
                        column++;
                    }
                    term_b_class_count[i] = category_count[i] - term_class_count[i];
                    term_count += term_class_count[i];
                }

                // H(C|T) = P(t) * H(C|t) + P(not t) * H(C|not t)
                double HCT = 1.0 * term_count / doc_num * getEntropy(term_class_count)
                        + 1.0 * (doc_num - term_count) / doc_num * getEntropy(term_b_class_count);
                double IG = HC - HCT;
                pw.println(term + "\t" + String.valueOf(IG));
            }
            // PrintWriter never throws on write failures; it only records them.
            if (pw.checkError()) {
                System.err.println("Error while writing " + IGFile);
            }
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            // Release both streams even when a row could not be processed.
            if (pw != null) {
                pw.close();
            }
            if (br != null) {
                try {
                    br.close();
                } catch (Exception e) {
                    e.printStackTrace();
                }
            }
        }
    }

    /**
     * Returns the base-2 entropy of the distribution obtained by normalising the
     * given counts: log2(sum) - (1/sum) * SUM(arr[i] * log2(arr[i])).
     *
     * @param arr non-negative counts
     * @return the entropy in bits; 0 when the counts sum to zero (an empty
     *         distribution carries no information), instead of NaN
     */
    public double getEntropy(int[] arr) {
        int sum = 0;
        double weighted = 0.0;
        for (int i = 0; i < arr.length; i++) {
            sum += arr[i];
            if (arr[i] > 0) { // a zero count contributes 0 * log(0) := 0
                weighted += arr[i] * Math.log(arr[i]) / Math.log(2);
            }
        }
        if (sum <= 0) {
            return 0.0;
        }
        double entropy = weighted / sum - Math.log(sum) / Math.log(2);
        return 0 - entropy;
    }

    /** Runs the IG computation on the fixed input/output paths and prints start and end times. */
    public static void main(String[] args) throws Exception {
        Date currentTime = new Date();
        SimpleDateFormat formatter = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
        System.out.println("Begin Time: " + formatter.format(currentTime));
        CalIG inst = new CalIG();
        File in = new File("/home/orisun/matrix/part-r-00000");
        File out = new File("/home/orisun/frequency1");
        inst.calIG(in, out);
        currentTime = new Date();
        System.out.println("End Time: " + formatter.format(currentTime));
    }
}
本来生成了大约32.6万个term-ig对,后来我去掉那些以数字、字母、符号开头的term,剩下117342个term。假如需要4000个特征项,那么我们就要取出ig值最大的前4000个term(注意并不需要对117342条记录全部排序)。可以看我的另外一篇博客《寻找N个元素中的前K个最大者》,这里就不再重复介绍了,直接上代码:
#ifndef ITERM_H
#define ITERM_H
#include<ostream>
#include<string>
using namespace std;
/// A (term, information gain) record. All comparisons look at the information
/// gain only; the term text never takes part in ordering or equality.
class iterm{
private:
    std::string term; // the feature word
    double ig;        // information gain associated with the term
public:
    /// Builds a record from a term and its information gain.
    iterm(std::string term,double ig):term(term),ig(ig){}

    // The comparison operators are const so they also work on const objects
    // and const references.
    bool operator == (const iterm & i2) const {
        return ig==i2.ig;
    }
    bool operator > (const iterm & i2) const {
        return ig>i2.ig;
    }
    bool operator < (const iterm & i2) const {
        return ig<i2.ig;
    }

    /// Legacy member form, kept for source compatibility. As a member it is
    /// invoked with the record on the left: `item << out`.
    std::ostream& operator << (std::ostream& out) const {
        out<<term<<"\t"<<ig;
        return out;
    }

    /// Conventional stream insertion: `out << item` writes "term<TAB>ig".
    friend std::ostream& operator << (std::ostream& out,const iterm& item){
        out<<item.term<<"\t"<<item.ig;
        return out;
    }

    /// Returns a copy of the term text.
    std::string getTerm() const {
        return term;
    }
    /// Returns the information gain.
    double getIG() const {
        return ig;
    }
};
#endif
#include<iostream>
#include<cstdlib>
#include<ctime>
#include<vector>
#include<fstream>
#include<sstream>
#include"iterm.h"
/// Sifts the element at position `index` down a binary min-heap stored in
/// `vec` (children of node i live at 2*i+1 and 2*i+2) until neither child is
/// smaller than it.
///
/// Only `operator<` is required of Comparable. An index that is negative or
/// not inside the vector leaves the vector untouched.
template<typename Comparable>
void percolate(std::vector<Comparable> &vec,int index){
    using std::swap; // keep ADL available for types that provide their own swap
    if(index<0)
        return;
    const std::size_t size=vec.size();
    std::size_t parent=static_cast<std::size_t>(index);
    std::size_t child=2*parent+1;
    while(child<size){
        // Pick the smaller of the two children.
        if(child+1<size && vec[child+1]<vec[child])
            ++child;
        // Heap property holds once the smaller child is not below the parent;
        // ties need no swap.
        if(!(vec[child]<vec[parent]))
            break;
        swap(vec[parent],vec[child]);
        parent=child;
        child=2*parent+1;
    }
}
// Turns vec into a heap in place by calling percolate on every position from
// (size-1)/2 down to 0, i.e. from the back of the internal nodes to the root.
template<typename Comparable>
void buildHeap(vector<Comparable> &vec){
    const int count=vec.size();
    int node=(count-1)/2;
    while(node>=0){
        percolate(vec,node);
        --node;
    }
}
int main(){
clock_t t1=clock();
const int K=4000;
vector<iterm> vec;
string infn="frequency1";
string outfn="features2";
ifstream infile(infn.c_str(),ios::in);
ofstream outfile(outfn.c_str(),ios::out);
string line,word,sd;
double d;
int n=K;
while(getline(infile,line)){
istringstream sstr(line);
sstr>>word;
sstr>>sd;
d=atof(sd.c_str());
iterm inst(word,d);
if(n>0){
vec.push_back(inst);
n--;
}
else{
if(n==0){
buildHeap(vec);
n=-1;
}
if(inst>vec[0]){
vec[0]=inst;
percolate(vec,0);
}
}
}
infile.close();
for(int i=0;i<K;i++)
outfile<<vec[i].getTerm()<<"\t"<<vec[i].getIG()<<endl;
outfile.close();
clock_t t2=clock();
cout<<"Time:"<<(t2-t1)<<endl;
return 0;
}
这段时间一直在用Java,今天又体验了一把C++的极速,执行上面代码只用了0.25秒!
如果实际工作需要,可以再对这4000个term来一次内排序。