// Source code (代码)
package basic.util;
import java.io.StringReader;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.wltea.analyzer.lucene.IKAnalyzer;
/**
 * Computes a TF-IDF style similarity score between two texts.
 *
 * <p>Both texts are tokenized with the IK analyzer (constructed with {@code true}). A per-term
 * table is then built holding, for each term, its occurrence count in each text, the number of
 * texts containing it, and TF / IDF-style weights. Finally the two weighted sums are compared.
 *
 * <p>Note: despite its name, {@link #getCos} does not compute a cosine. It returns
 * {@code min(sum1 / sum2, sum2 / sum1)} of the per-text weighted sums. The result is in [0, 1],
 * and equals 1 when the sums are equal (for example, when both texts are identical).
 */
public class Tf_tdfUtil2 {
    private static final String TOTAL_COUNT1 = "totalCount1";
    private static final String TOTAL_COUNT2 = "totalCount2";
    private static final String FILE_COUNT = "fileCount";
    private static final String TF1 = "tf1";
    private static final String TF2 = "tf2";
    private static final String TDF1 = "tdf1";
    private static final String TDF2 = "tdf2";
    /** Smoothing term added to the document frequency in the IDF denominator. */
    private static final double TDF_ASSIST = 1d;
    /** Number of documents being compared; the numerator of the IDF ratio. */
    private static final double DOCUMENT_COUNT = 2d;
    /** Decimal places kept when dividing the weighted sums in {@link #getCos}. */
    private static final int RESULT_SCALE = 6;

    /**
     * Demo entry point.
     *
     * <p>Compares text 1 with itself, then text 1 with text 2. Both texts are read from
     * hard-coded paths.
     */
    public static void main(String[] args) {
        try {
            String str1 = TextUtil.getFileContentByPath("E:\\myself\\tf_tdf\\1.txt");
            String str2 = TextUtil.getFileContentByPath("E:\\myself\\tf_tdf\\2.txt");
            // Self-comparison first, so the output shows what a "fully similar" score looks like.
            printSimilarity(str1, str1);
            printSimilarity(str1, str2);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /** Prints both contents followed by their similarity score, framed by separator lines. */
    private static void printSimilarity(String content1, String content2) throws Exception {
        System.out.println("***************************************");
        System.out.println("内容1:" + content1);
        System.out.println("内容2:" + content2);
        System.out.println("内容1与内容2的相似度为(1:完全相似,0:完全不相似)为:" + getTfTdfValue(content1, content2));
        System.out.println("***************************************");
    }

    /**
     * Tokenizes both texts and returns their similarity score.
     *
     * @param str1 first text; must be non-null and non-empty
     * @param str2 second text; must be non-null and non-empty
     * @return the score from {@link #getCos}, as a {@code double}
     * @throws Exception if either text is null/empty or tokenization fails
     */
    public static double getTfTdfValue(String str1, String str2) throws Exception {
        List<String> terms1 = cutChar(str1);
        List<String> terms2 = cutChar(str2);
        Map<String, Map<String, Double>> charVectorTable = getCharVetorTable(terms1, terms2);
        // getCos returns BigDecimal; convert explicitly to honour this method's double return type.
        return getCos(charVectorTable).doubleValue();
    }

    /**
     * Builds the per-term statistics table for two token lists.
     *
     * <p>For every distinct term, the inner map holds:
     * <ul>
     *   <li>{@code totalCount1} / {@code totalCount2}: occurrences in each list;
     *   <li>{@code fileCount}: how many of the two lists contain the term (1 or 2);
     *   <li>{@code tf1} / {@code tf2}: occurrences divided by the combined token count of both
     *       lists;
     *   <li>{@code tdf1} / {@code tdf2}: {@code ln(2 / (fileCount + 1))}, identical for both keys.
     * </ul>
     *
     * @param charList1 tokens of the first text
     * @param charList2 tokens of the second text
     * @return a mutable map from term to its statistics map
     */
    public static Map<String, Map<String, Double>> getCharVetorTable(List<String> charList1, List<String> charList2) {
        Map<String, Map<String, Double>> vectorMap = new HashMap<>();
        countTerms(vectorMap, charList1, TOTAL_COUNT1);
        countTerms(vectorMap, charList2, TOTAL_COUNT2);
        double totalCharCount = charList1.size() + charList2.size();
        vectorMap.forEach((term, vector) -> {
            // The IDF weight depends only on fileCount, so both texts share one value per term.
            double tdf = Math.log(DOCUMENT_COUNT / (vector.get(FILE_COUNT) + TDF_ASSIST));
            vector.put(TF1, vector.get(TOTAL_COUNT1) / totalCharCount);
            vector.put(TDF1, tdf);
            vector.put(TF2, vector.get(TOTAL_COUNT2) / totalCharCount);
            vector.put(TDF2, tdf);
        });
        return vectorMap;
    }

    /** Adds one document's term occurrences to {@code vectorMap} under {@code countKey}. */
    private static void countTerms(Map<String, Map<String, Double>> vectorMap, List<String> terms, String countKey) {
        for (String term : terms) {
            Map<String, Double> vector = vectorMap.computeIfAbsent(term, k -> newVector());
            if (vector.get(countKey) == 0d) {
                // First occurrence of this term in this document: count the document once.
                // Counting every occurrence would inflate the document frequency used by the IDF.
                vector.merge(FILE_COUNT, 1d, Double::sum);
            }
            vector.merge(countKey, 1d, Double::sum);
        }
    }

    /** Creates an empty statistics map with all counters initialised to zero. */
    private static Map<String, Double> newVector() {
        Map<String, Double> vector = new HashMap<>();
        vector.put(TOTAL_COUNT1, 0d);
        vector.put(TOTAL_COUNT2, 0d);
        vector.put(FILE_COUNT, 0d);
        return vector;
    }

    /**
     * Compares the TF * IDF weighted sums of the two texts.
     *
     * @param charVectorTable table produced by {@link #getCharVetorTable}
     * @return {@code min(sum1 / sum2, sum2 / sum1)} rounded to 6 decimal places; a ratio whose
     *     denominator is zero counts as 0
     */
    public static BigDecimal getCos(Map<String, Map<String, Double>> charVectorTable) {
        BigDecimal sum1 = BigDecimal.ZERO;
        BigDecimal sum2 = BigDecimal.ZERO;
        for (Map<String, Double> vector : charVectorTable.values()) {
            sum1 = sum1.add(BigDecimal.valueOf(vector.get(TF1)).multiply(BigDecimal.valueOf(vector.get(TDF1))));
            sum2 = sum2.add(BigDecimal.valueOf(vector.get(TF2)).multiply(BigDecimal.valueOf(vector.get(TDF2))));
        }
        BigDecimal result1 = ratio(sum1, sum2);
        BigDecimal result2 = ratio(sum2, sum1);
        return result1.compareTo(result2) < 0 ? result1 : result2;
    }

    /** Returns {@code numerator / denominator} at {@link #RESULT_SCALE}, or 0 for a zero denominator. */
    private static BigDecimal ratio(BigDecimal numerator, BigDecimal denominator) {
        if (denominator.signum() == 0) {
            return BigDecimal.ZERO;
        }
        // Explicit scale + RoundingMode: the legacy divide(BigDecimal, int) overload treats the
        // int as a rounding mode, not as the number of decimal places.
        return numerator.divide(denominator, RESULT_SCALE, RoundingMode.HALF_EVEN);
    }

    /**
     * Splits {@code str} into tokens using the IK analyzer.
     *
     * @param str text to tokenize
     * @return the token texts in the order the analyzer emits them
     * @throws IllegalArgumentException if {@code str} is null or empty
     * @throws Exception if the analyzer fails while reading the text
     */
    public static List<String> cutChar(String str) throws Exception {
        if (str == null || str.isEmpty()) {
            throw new IllegalArgumentException("字符串为空");
        }
        List<String> stringList = new ArrayList<>();
        // try-with-resources so the analyzer and token stream are released even on failure.
        try (StringReader reader = new StringReader(str);
             Analyzer analyzer = new IKAnalyzer(true);
             TokenStream ts = analyzer.tokenStream("", reader)) {
            CharTermAttribute term = ts.addAttribute(CharTermAttribute.class);
            ts.reset();
            while (ts.incrementToken()) {
                stringList.add(term.toString());
            }
            ts.end();
        }
        return stringList;
    }
}
// Run output screenshot (运行图) — the image is not included in this listing.