声明
以下代码只是对tf-idf算法思想的基本实现,因此许多地方有待完善,总结如下:
1.实现逻辑问题:处于特殊位置(比如段首)的词,以及名词(相对于动词),应该有更大的权重;
2.分词前应该对文本进行基本处理(如去掉标点),并以合适的方式调用分词接口,使得文本内容较大时能够分多次调用,且结果与一次调用相同;
3.速度有待提升:总文本数一星期更新一次即可;包含关键词的文本数的测量方式也有待改进;
实现
package demo.utils;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Component;
import org.springframework.web.client.RestTemplate;
import java.io.UnsupportedEncodingException;
import java.net.URI;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.concurrent.*;
import java.util.function.Function;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
/**
 * TF-IDF keyword extraction built on the HIT LTP word-segmentation service and Bing result counts.
 *
 * <p>TF comes from the segmented text. IDF is approximated as {@code log(N / df)}, where {@code N}
 * is the number of Bing results reported for the very common word "的" (a stand-in for the total
 * number of documents) and {@code df} is the number of Bing results reported for the term itself.
 *
 * @author 杜艮魁
 * @date 2017/12/1
 */
@Component
public class LTPUtils {

    /** Bing search endpoint; the URL-encoded query is appended to it. */
    private static final String BING_SEARCH_URL = "https://cn.bing.com/search?q=";

    /** Very common word whose Bing result count is used as the total document count N. */
    private static final String TOTAL_DOCS_QUERY = "的";

    /**
     * Matches Bing's result-count text such as "1,234,567 条结果". Group 1 captures the number
     * itself, so no substring arithmetic on the separator character is needed.
     */
    private static final Pattern RESULT_COUNT_PATTERN = Pattern.compile("(\\d+(?:,\\d{3})*)\\s条结果");

    private static final Logger LOG = Logger.getLogger(LTPUtils.class.getName());

    @Value("${demo.ltp-url}")
    private String ltpUrl;

    @Value("${demo.api-key}")
    private String apiKey;

    /** Shared HTTP client; Spring documents RestTemplate as thread-safe once constructed. */
    private final RestTemplate restTemplate = new RestTemplate();

    /** Executor that runs the LTP HTTP calls. */
    private final ExecutorService pool;

    public LTPUtils() {
        ThreadFactory namedThreadFactory = new ThreadFactoryBuilder()
                .setNameFormat("Thread-pool-%d")
                // Nothing in this class shuts the pool down, and core threads never time out,
                // so use daemon threads to avoid keeping the JVM alive on shutdown.
                .setDaemon(true)
                .build();
        pool = new ThreadPoolExecutor(5, 20, 0L, TimeUnit.MILLISECONDS,
                new LinkedBlockingDeque<Runnable>(1024), namedThreadFactory,
                new ThreadPoolExecutor.AbortPolicy());
    }

    /**
     * Computes TF-IDF scores for the words of {@code content}.
     *
     * @param content text to segment and score
     * @return each distinct word mapped to its TF-IDF value, ordered by value descending;
     *         empty if {@code content} is null or blank
     * @throws RuntimeException if the segmentation request fails (the original exception is kept
     *                          as the cause)
     */
    public Map<String, Double> tfIdf(String content) {
        if (content == null || content.trim().isEmpty()) {
            return new LinkedHashMap<>();
        }
        String[] words;
        try {
            words = getInfoForLTP(content, "ws", "plain");
        } catch (InterruptedException e) {
            // Restore the interrupt flag so callers up the stack can still observe it.
            Thread.currentThread().interrupt();
            throw new RuntimeException(e.getMessage(), e);
        } catch (ClassNotFoundException | ExecutionException e) {
            throw new RuntimeException(e.getMessage(), e);
        }
        Map<String, Double> tf = countTF(words);
        if (tf.isEmpty()) {
            return new LinkedHashMap<>();
        }
        // N does not depend on the term, so fetch it once instead of once per term.
        OptionalLong totalDocs = totalDocumentCount();
        Map<String, Double> scores = new HashMap<>();
        for (Map.Entry<String, Double> entry : tf.entrySet()) {
            scores.put(entry.getKey(), entry.getValue() * idf(entry.getKey(), totalDocs));
        }
        return scores.entrySet().stream()
                .sorted(Map.Entry.comparingByValue(Collections.reverseOrder()))
                .collect(Collectors.toMap(
                        Map.Entry::getKey,
                        Map.Entry::getValue,
                        (first, second) -> first,
                        LinkedHashMap::new));
    }

    /**
     * Calls the HIT LTP service and splits its response body into tokens.
     *
     * @param text    text to analyse
     * @param pattern LTP analysis pattern, e.g. {@code "ws"} for word segmentation
     * @param format  LTP response format, e.g. {@code "plain"}
     * @return whitespace-separated tokens of the response body; empty if the body is null or blank
     * @throws ClassNotFoundException never thrown; kept for source compatibility with callers
     * @throws ExecutionException     if the HTTP request fails
     * @throws InterruptedException   if interrupted while waiting for the response
     */
    public String[] getInfoForLTP(String text, String pattern, String format)
            throws ClassNotFoundException, ExecutionException, InterruptedException {
        // Encode every parameter value. Raw text containing '&', '#' or '+' would corrupt the
        // query string, and '{' would be treated by RestTemplate's String overloads as a
        // URI-template variable. Passing a java.net.URI bypasses template expansion.
        URI uri = URI.create(ltpUrl
                + "?api_key=" + urlEncode(apiKey)
                + "&text=" + urlEncode(text)
                + "&pattern=" + urlEncode(pattern)
                + "&format=" + urlEncode(format));
        Future<ResponseEntity<String>> response =
                pool.submit(() -> restTemplate.getForEntity(uri, String.class));
        String body = response.get().getBody();
        if (body == null || body.trim().isEmpty()) {
            return new String[0];
        }
        // Split on whitespace runs so repeated spaces or line breaks (if the service emits them)
        // do not yield empty or merged tokens.
        return body.trim().split("\\s+");
    }

    /**
     * Counts term frequencies, normalised by the total number of tokens.
     *
     * @param strArrs tokens to count; may be null or empty
     * @return each distinct token mapped to (occurrences / strArrs.length); empty for null or
     *         empty input
     */
    public Map<String, Double> countTF(String[] strArrs) {
        Map<String, Double> frequencies = new HashMap<>();
        if (strArrs == null || strArrs.length == 0) {
            return frequencies;
        }
        double tokenCount = strArrs.length;
        Map<String, Long> counts = Arrays.stream(strArrs)
                .collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));
        counts.forEach((word, count) -> frequencies.put(word, count / tokenCount));
        return frequencies;
    }

    /**
     * Approximates the IDF of {@code str} as {@code log(N / df)} from Bing result counts.
     *
     * <p>The value is negative when Bing reports more results for the term than for "的".
     *
     * @param str term to look up
     * @return the IDF value, or 0 if a request fails or a result count cannot be read
     */
    public double getIDF(String str) {
        return idf(str, totalDocumentCount());
    }

    /** Fetches N, the result count for "的"; empty if the request fails or no count is found. */
    private OptionalLong totalDocumentCount() {
        try {
            return searchResultCount(TOTAL_DOCS_QUERY);
        } catch (RuntimeException e) {
            LOG.log(Level.WARNING, "Bing search for the total document count failed", e);
            return OptionalLong.empty();
        }
    }

    /** Computes log(N / df(term)); returns 0 when either count is unavailable or not positive. */
    private double idf(String term, OptionalLong totalDocs) {
        if (!totalDocs.isPresent() || totalDocs.getAsLong() <= 0) {
            LOG.warning("返回结果有误:" + term);
            return 0;
        }
        OptionalLong docFreq;
        try {
            docFreq = searchResultCount(term);
        } catch (RuntimeException e) {
            LOG.log(Level.WARNING, "Bing search failed for term: " + term, e);
            return 0;
        }
        if (!docFreq.isPresent() || docFreq.getAsLong() <= 0) {
            LOG.warning("返回结果有误:" + term);
            return 0;
        }
        // Divide as doubles. Long division truncated the ratio, gave log(0) = -Infinity
        // whenever df > N, and threw ArithmeticException when df was 0.
        return Math.log((double) totalDocs.getAsLong() / docFreq.getAsLong());
    }

    /** Runs a Bing search for {@code query} and extracts the reported result count. */
    private OptionalLong searchResultCount(String query) {
        String html = restTemplate.getForObject(
                URI.create(BING_SEARCH_URL + urlEncode(query)), String.class);
        return parseResultCount(html);
    }

    /** Extracts the first "N 条结果" count from {@code html}; empty if absent or unparsable. */
    private static OptionalLong parseResultCount(String html) {
        if (html == null) {
            return OptionalLong.empty();
        }
        Matcher m = RESULT_COUNT_PATTERN.matcher(html);
        if (!m.find()) {
            return OptionalLong.empty();
        }
        try {
            return OptionalLong.of(Long.parseLong(m.group(1).replace(",", "")));
        } catch (NumberFormatException e) {
            return OptionalLong.empty();
        }
    }

    /** URL-encodes a query-parameter value as UTF-8; null is encoded as the empty string. */
    private static String urlEncode(String value) {
        if (value == null) {
            return "";
        }
        try {
            return URLEncoder.encode(value, StandardCharsets.UTF_8.name());
        } catch (UnsupportedEncodingException e) {
            // UTF-8 is a standard charset that every Java platform must support.
            throw new IllegalStateException(e);
        }
    }
}