Java 实现文本分类中基于卡方检验（χ²）的特征选择
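在文本分类的特征选择阶段，一般以“词汇 t 与类别 c 不相关”作为原假设。计算出的卡方值越大，说明观测结果对原假设的偏离越大，我们就越倾向于认为原假设的反面（即词汇与类别相关）成立。选择过程为：对每个词计算它与类别 c 的卡方值，按从大到小排序（卡方值越大越相关），取前 k 个即可。所以卡方值越大，词汇与类别越相关。
对每个词汇 t 统计一个 2×2 列联表：A 为包含 t 且属于类别 c 的文档数，B 为包含 t 但不属于 c 的文档数，C 为不包含 t 但属于 c 的文档数，D 为既不包含 t 也不属于 c 的文档数，N = A + B + C + D 为文档总数。卡方值为
$$\chi^2(t,c) = \frac{N\,(AD-BC)^2}{(A+C)(A+B)(B+D)(C+D)}$$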
最终结果如下，每个特征对应一个卡方值，按卡方值从大到小排列：
挺 81613.22937057534
特别 80715.8079021637
太 77021.62013902201
按摩 77012.332216897
喜欢 74483.97317313829
好吃 68723.25474753922
推荐 64809.29378047458
朋友 55988.55238512035
里 55371.21236337274
舒服 54673.57763700053
服务员 53918.5463140205
手法 53379.51974620014
体验 52445.83665612151
代码实现如下：
package com.meituan.model.libsvm;
/**
 * 2x2 contingency counts for one word and the chi-square statistic derived from them.
 *
 * <p>Cells follow the usual term/class layout:
 * <ul>
 *   <li>A = documents in the target class that contain the word</li>
 *   <li>B = documents outside the target class that contain the word</li>
 *   <li>C = documents in the target class that do not contain the word</li>
 *   <li>D = documents outside the target class that do not contain the word</li>
 * </ul>
 * Counts are stored as doubles so the statistic is computed in floating point.
 */
public class ChiSquare {
    private String word;
    private double A;
    private double B;
    private double C;
    private double D;
    /**
     * Statistic cached by the 5-argument constructor and by {@link #setChisq()}. It is NOT
     * refreshed by the incr/decrement methods; call {@link #getChisq()} for the current value.
     */
    public double chisq;

    public ChiSquare() {
    }

    public ChiSquare(String word) {
        this.word = word;
    }

    /** Creates the counts for {@code word} with the given cells and caches the statistic. */
    public ChiSquare(String word, int a, int b, int c, int d) {
        super();
        this.word = word;
        this.A = a;
        this.B = b;
        this.C = c;
        this.D = d;
        this.chisq = getChisq();
    }

    public String getWord() {
        return word;
    }

    public void setWord(String word) {
        this.word = word;
    }

    public double getA() {
        return A;
    }

    public void setA(int a) {
        this.A = a;
    }

    public double getB() {
        return B;
    }

    public void setB(int b) {
        this.B = b;
    }

    public double getC() {
        return C;
    }

    public void setC(int c) {
        this.C = c;
    }

    public double getD() {
        return D;
    }

    public void setD(int d) {
        this.D = d;
    }

    /** Increments A by one. */
    public void incrA() {
        A = A + 1;
    }

    /** Increments B by one. */
    public void incrB() {
        B++;
    }

    /** Decrements C by one. */
    public void decrementC() {
        C--;
    }

    /** Decrements D by one. */
    public void decrementD() {
        D--;
    }

    /** Recomputes the statistic from the current cells and stores it in {@link #chisq}. */
    public void setChisq() {
        this.chisq = getChisq();
    }

    /**
     * Computes chi-square = N (AD - BC)^2 / ((A + C)(A + B)(B + D)(C + D)),
     * where N = A + B + C + D.
     *
     * @return the statistic, or 0 when any marginal total is zero. In that case the word, or one
     *     class, is absent everywhere, so AD - BC is also 0 and the formula would be 0/0 (NaN).
     */
    public double getChisq() {
        double denominator = (A + C) * (A + B) * (B + D) * (C + D);
        if (denominator == 0) {
            return 0;
        }
        // N is derived from the cells instead of being fixed to one data set's size.
        double n = A + B + C + D;
        return Math.pow(A * D - B * C, 2) * n / denominator;
    }

    @Override
    public String toString() {
        return String.format(
                "word=%s ChiSquare=%f [A=%f , B=%f , C=%f , D=%f ]",
                this.word, this.getChisq(), this.A, this.B, this.C, this.D);
    }

    public static void main(String[] args) {
        ChiSquare chiSquare = new ChiSquare("搀扶", 17, 1, 20936, 134784);
        System.out.println(chiSquare.toString());
    }
}
package com.meituan.model.libsvm;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.stream.Collectors;
import org.ansj.splitWord.analysis.ToAnalysis;
import com.meituan.model.util.Config;
import com.meituan.nlp.util.TextUtil;
import com.meituan.nlp.util.WordUtil;
public class ChisQuareSelect {
private static String inputpath = Config.getString("data.path");
private static Map mapchis = new HashMap();
public static void main(String[] args) {
System.setProperty("java.util.Arrays.useLegacyMergeSort", "true");
int[] arr = getTotal(inputpath);
System.out.println("arr is :"+Arrays.toString(arr));
init(inputpath, arr);
featureselect(inputpath);
System.out.println("排序开始");
List> list = new ArrayList>(
mapchis.entrySet());
Collections.sort(list, new Comparator>() {
@Override
public int compare(Entry o1,
Entry o2) {
double result = o1.getValue().getChisq()
- o2.getValue().getChisq();
if (result > 0) {
return -1;
}else{
return 1;
}
}
});
write(list, "file/chisq.txt", arr);
System.out.println("success");
}
public static void write(List> list, String out, int[] arr) {
BufferedWriter bw = null;
int sum = 0;
for (int ar : arr) {
sum += ar;
}
System.out.println("sum is :"+sum);
try {
bw = new BufferedWriter(new OutputStreamWriter(
new FileOutputStream(out)));
for (Map.Entry map : list) {
bw.write(map.getKey() + " " + map.getValue().getChisq()
+ "\n");
if(Math.random()<0.001){
System.out.println(map.getValue().toString());
}
}
} catch (Exception e) {
e.printStackTrace();
} finally {
if (bw != null) {
try {
bw.close();
} catch (IOException e) {
e.printStackTrace();
}
}
}
}
public static int[] getTotal(String in) {
BufferedReader br = null;
int[] arr = new int[2];
int totalP = 0;
int totalN = 0;
try {
br = new BufferedReader(new InputStreamReader(new FileInputStream(
in)));
String lines = br.readLine();
while (lines != null) {
String label = lines.split("\t")[1].equalsIgnoreCase("-1") ? "-1"
: "1";
if (label.equals("-1")) {
totalN++;
} else {
totalP++;
}
lines = br.readLine();
}
} catch (Exception e) {
e.printStackTrace();
} finally {
if (br != null) {
try {
br.close();
} catch (IOException e) {
e.printStackTrace();
}
}
}
arr[0] = totalP;
arr[1] = totalN;
return arr;
}
public static void init(String in, int[] arr) {
BufferedReader br = null;
try {
br = new BufferedReader(new InputStreamReader(new FileInputStream(
in)));
String lines = br.readLine();
while (lines != null) {
String content = lines.split("\t")[0];
Set sets = ToAnalysis
.parse(WordUtil.replaceAllSynonyms(TextUtil
.fan2Jian(WordUtil.replaceAll(content
.toLowerCase()))))
.getTerms()
.stream()
.map(x -> x.getName())
.filter(x ->WordUtil.isChineseChar(x) && !WordUtil.isStopword(x)
&& !WordUtil.startWithNumeber(x))
.collect(Collectors.toSet());
for (String s : sets) {
if (!mapchis.containsKey(s)) {
mapchis.put(s, new ChiSquare(s, 0, 0, arr[0], arr[1]));
}
}
lines = br.readLine();
}
} catch (Exception e) {
e.printStackTrace();
} finally {
if (br != null) {
try {
br.close();
} catch (IOException e) {
e.printStackTrace();
}
}
}
}
public static void featureselect(String in) {
BufferedReader br = null;
try {
br = new BufferedReader(new InputStreamReader(new FileInputStream(
in)));
String lines = br.readLine();
while (lines != null) {
String content = lines.split("\t")[0];
String label = lines.split("\t")[1].equalsIgnoreCase("-1") ? "-1"
: "1";
Set sets = ToAnalysis
.parse(WordUtil.replaceAllSynonyms(TextUtil
.fan2Jian(WordUtil.replaceAll(content
.toLowerCase()))))
.getTerms()
.stream()
.map(x -> x.getName())
.filter(x -> WordUtil.isChineseChar(x) && !WordUtil.isStopword(x)
&& !WordUtil.startWithNumeber(x))
.collect(Collectors.toSet());
for (String s : sets) {
if ("-1".equalsIgnoreCase(label)) {
mapchis.get(s).incrA();
mapchis.get(s).decrementC();
} else if ("1".equalsIgnoreCase(label)) {
mapchis.get(s).incrB();
mapchis.get(s).decrementD();
}
}
lines = br.readLine();
}
} catch (Exception e) {
e.printStackTrace();
} finally {
if (br != null) {
try {
br.close();
} catch (IOException e) {
e.printStackTrace();
}
}
}
}
}