import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
/**
*
* @author qibaoyuan
*
*/
/**
 * Computes the information gain (IG) of feature columns with respect to the
 * label stored in the last column of each record.
 *
 * <p>Column combinations are obtained from {@code PermutationTest.genPerLess};
 * presumably these are combinations of up to 3 column indices (TODO confirm
 * against PermutationTest).
 *
 * @author qibaoyuan
 */
public class InformationGain {
    /**
     * Calculates the base-2 Shannon entropy of a list of class labels.
     *
     * <p>Labels are compared after {@link String#trim()}, so {@code "YES"} and
     * {@code " YES "} count as the same class.
     *
     * @param classes the class label of each sample; {@code null} or empty yields 0.0
     * @return entropy in bits
     * @throws NullPointerException if {@code classes} contains a {@code null} element
     */
    static Double calculateEntrophy(List<String> classes) {
        if (classes == null || classes.isEmpty()) {
            return 0.0;
        }
        // Count occurrences of each (trimmed) label.
        Map<String, Integer> counter = new HashMap<String, Integer>();
        for (String label : classes) {
            String key = label.trim();
            Integer count = counter.get(key);
            counter.put(key, count == null ? 1 : count + 1);
        }
        double size = classes.size();
        double info = 0.0;
        for (Integer count : counter.values()) {
            double ratio = count / size;
            info -= ratio * (Math.log(ratio) / Math.log(2));
        }
        return info;
    }

    /**
     * Computes the information gain of every feature combination, prints each
     * score and a ranking sorted by descending IG, and returns the scores.
     *
     * @param records input records; the last element of each array is the label,
     *            e.g. {@code {[我 n 1 0 0 0 0 0 YES],[是 n 0 0 0 0 0 0 NO]}}
     * @param isSingleFeature when {@code true}, combinations spanning more than
     *            one column are skipped entirely; {@code null} is treated as false
     * @return map from combination id (as produced by
     *         {@code PermutationTest.genPerLess}) to its information gain; empty
     *         if {@code records} is null/empty or an exception occurred early
     */
    static Map<Integer, Double> calculateIG(List<String[]> records,
            Boolean isSingleFeature) {
        Map<Integer, Double> index4select = new HashMap<Integer, Double>();
        if (records == null || records.isEmpty()) {
            return index4select;
        }
        boolean singleOnly = Boolean.TRUE.equals(isSingleFeature);
        try {
            // 1. Entropy of the label column (last column of every record).
            List<String> labels = new ArrayList<String>(records.size());
            int featureSize = 0;
            for (String[] arr : records) {
                labels.add(arr[arr.length - 1]);
                // NOTE(review): taken from the last record; assumes every record
                // has the same number of columns.
                featureSize = arr.length - 1;
            }
            Map<Integer, List<Object>> features = PermutationTest.genPerLess(
                    featureSize, 3);
            Double total = calculateEntrophy(labels);
            System.out.print("label的熵信息:");
            System.out.println(total);

            // 2. IG = H(label) - H(label | feature combination).
            for (Entry<Integer, List<Object>> combination : features.entrySet()) {
                List<Object> columns = combination.getValue();
                // Skip the whole combination rather than every record of it;
                // skipping records left it with IG == total entropy.
                if (singleOnly && columns.size() > 1) {
                    continue;
                }
                double ig = total - conditionalEntropy(records, columns);
                System.out.print("feature " + combination.getKey() + " ig:");
                System.out.println(ig);
                index4select.put(combination.getKey(), ig);
            }
            printRanking(index4select, features);
        } catch (Exception e) {
            e.printStackTrace();
        }
        return index4select;
    }

    /**
     * Groups records by their values in {@code columns} and returns the
     * size-weighted average label entropy of the groups.
     */
    private static double conditionalEntropy(List<String[]> records,
            List<Object> columns) {
        // A list key (instead of concatenated strings) keeps distinct value
        // tuples apart: "a"+"bc" and "ab"+"c" would otherwise collide.
        Map<List<String>, List<String>> labelsByValue = new HashMap<List<String>, List<String>>();
        for (String[] arr : records) {
            List<String> value = new ArrayList<String>(columns.size());
            for (Object column : columns) {
                if (column instanceof Integer) {
                    value.add(arr[(Integer) column]);
                }
            }
            // value is never mutated after being used as a key.
            List<String> group = labelsByValue.get(value);
            if (group == null) {
                group = new ArrayList<String>();
                labelsByValue.put(value, group);
            }
            group.add(arr[arr.length - 1]);
        }
        double info = 0.0;
        for (List<String> group : labelsByValue.values()) {
            info += ((double) group.size() / records.size())
                    * calculateEntrophy(group);
        }
        return info;
    }

    /** Prints combination ids with their columns, sorted by descending IG. */
    private static void printRanking(final Map<Integer, Double> scores,
            Map<Integer, List<Object>> features) {
        List<Integer> keys = new ArrayList<Integer>(scores.keySet());
        Collections.sort(keys, new Comparator<Integer>() {
            @Override
            public int compare(Integer a, Integer b) {
                return Double.compare(scores.get(b), scores.get(a));
            }
        });
        for (Integer key : keys) {
            System.out.println(key + "" + features.get(key) + "= "
                    + scores.get(key));
        }
    }

    /**
     * Reads tab-separated records from a file, skipping blank lines. The file
     * is always closed, even if reading fails.
     */
    private static List<String[]> readRecords(String file) throws IOException {
        // NOTE(review): FileReader uses the platform default charset; the
        // annotated corpus may contain Chinese text — confirm its encoding.
        BufferedReader br = new BufferedReader(new FileReader(file));
        try {
            List<String[]> lists = new ArrayList<String[]>();
            String line;
            while ((line = br.readLine()) != null) {
                if (line.trim().length() == 0) {
                    continue;
                }
                lists.add(line.split("\t"));
            }
            return lists;
        } finally {
            br.close();
        }
    }

    /**
     * Reads records from a file and prints the IG of every feature combination.
     * The last column of each line is the hand-annotated label.
     *
     * @param file path to the hand-annotated, tab-separated corpus
     */
    static void calculateIG(String file) {
        try {
            System.out.print(calculateIG(readRecords(file), false));
        } catch (Exception e) {
            // Entry-point boundary: report and continue, as before.
            e.printStackTrace();
        }
    }

    /**
     * @param args optional; {@code args[0]} overrides the default corpus path
     */
    public static void main(String[] args) {
        String file = (args != null && args.length > 0) ? args[0]
                : "/home/qibaoyuan/qibaoyuan/lexo/cv/all.txt";
        calculateIG(file);
    }
}
计算信息增益(Information Gain),考虑交叉feature
最新推荐文章于 2024-04-19 09:39:59 发布