训练模型
读取语料和词典进行训练
Scanner getCorpus = new Scanner(new BufferedInputStream(new FileInputStream(new File(corpus))), "UTF-8"); //语料
Scanner getDict = new Scanner(new BufferedInputStream(new FileInputStream(new File(dict))), "UTF-8"); //词典
//pair对即文本属于消极倾向<-1,"文本">
List<Pair<Integer, String>> lines = new LinkedList<Pair<Integer, String>>();
Map<String, Integer> features = new LinkedHashMap<String, Integer>();
//语料处理
while(getCorpus.hasNextLine()){
String line = getCorpus.nextLine();
String[] _ = line.trim().split("\t");
lines.add(new Pair<Integer, String>(_[0].equals("负面")? -1: 1, _[1])); //这里可以根据需要进行多分类
}
//词典处理
for(int i=0; getDict.hasNextLine(); ){
String line = getDict.nextLine();
features.put(line.trim(), ++i);
}
int[] cnt = new int[features.size()];
Arrays.fill(cnt, 0);
for(Pair<Integer, String> pr: lines){
int polarity = pr.getFirst();
String phrase = pr.getSecond();
double sum = 0;
List<Pair<Integer, Integer>> featurePos = StringUtil.featureExtract(phrase, features, 10, 1);
Set<Integer> num = new TreeSet<Integer>();
for(Pair<Integer, Integer> pos : featurePos){
String feature = phrase.substring(pos.getFirst(), pos.getSecond());
int id = features.get(feature);
sum += cnt[id];
cnt[id]++;
num.add(id);
}
//文件输出
if(!featurePos.isEmpty()) {
putProblem.printf("%d", polarity);
for(Integer id: num){
putProblem.printf(" %d:%.6f", id, cnt[id]/Math.sqrt(sum));
cnt[id]=0;
}
putProblem.println();
}
}
//使用lR进行训练
Parameter param = new Parameter(SolverType.L2R_LR, 10, 0.002);
Linear.train(Train.readProblem(new File(problem), -1), param).save(new File(model)); //训练结束获得model文件
加载特征词
// Load feature words: one term per line, lower-cased and trimmed, mapped to a 1-based id
// in order of first appearance (a duplicate term keeps its first id).
// NOTE(review): presumably this is the same dictionary used for training. Training does not
// lower-case terms, so ids only line up if the dictionary has no case-variant duplicates — confirm.
Map<String,Integer> featureMap = new HashMap<String, Integer>();
InputStream inputStream = java.util.Objects.requireNonNull(
        conf.getConfResourceAsInputStream("path://"), "feature dictionary resource not found");
String line;
int index = 0;
// Decode as UTF-8 explicitly (training reads the dictionary as UTF-8) instead of the platform
// default charset. Closing the reader also closes inputStream, even on exception.
try (BufferedReader br = new BufferedReader(
        new InputStreamReader(inputStream, java.nio.charset.StandardCharsets.UTF_8))) {
    while ((line = br.readLine()) != null) {
        // Locale.ROOT keeps lower-casing independent of the JVM's default locale.
        line = line.toLowerCase(java.util.Locale.ROOT).trim();
        if (!featureMap.containsKey(line)) {
            featureMap.put(line, ++index);
        }
    }
}
获取特征向量
// Build the sparse feature vector for `content`, using the same L2-normalized term counts
// as the training problem file.
List<String> kws;
// Occurrence count per feature id. Index 0 is unused because ids start at 1.
// (Previously a null `private` field, which neither compiled as a local nor survived Arrays.fill.)
int[] cnt = new int[featureMap.size() + 1];
ArrayList<Feature> features = new ArrayList<Feature>();
// Extract keywords with the matching algorithm (e.g. forward / backward matching).
kws = StrAlg.Content(content);
// TreeSet keeps ids ascending, matching the order used when training.
Set<Integer> num = new TreeSet<Integer>();
for (String kw : kws) {
    Integer id = featureMap.get(kw);
    if (id != null) {
        cnt[id]++;
        num.add(id);
    }
}
// L2 norm of the count vector. It is > 0 whenever num is non-empty, so there is no
// division by zero (the old running sum was 0 when every keyword occurred once).
double sumSquares = 0;
for (int id : num) {
    sumSquares += (double) cnt[id] * cnt[id];
}
double norm = Math.sqrt(sumSquares);
for (Integer id : num) {
    features.add(new FeatureNode(id, cnt[id] / norm));
}
依据训练出的模型进行预测分类
// Classify the feature vector with the trained model and return a signed score:
// +P(+1) when the positive class wins, -P(-1) when the negative class wins, and 0 when the two
// probabilities are within 0.1 of each other (or neither exceeds the threshold).
double threshold = 0.5; // decision threshold for assigning a class
Feature[] instance = features.toArray(new Feature[0]);
double[] prob = new double[model.getNrClass()];
Linear.predictProbability(model, instance, prob);
// prob[i] is the probability of class model.getLabels()[i]. Look the labels up explicitly
// instead of assuming which index holds the positive class.
int[] labels = model.getLabels();
double positive = 0;
double negative = 0;
for (int i = 0; i < labels.length; i++) {
    if (labels[i] == 1) {
        positive = prob[i];
    } else if (labels[i] == -1) {
        negative = prob[i];
    }
}
if (Math.abs(positive - negative) < 0.1) {
    return 0; // too close to call: treat as neutral
}
if (positive > threshold) {
    return positive;
} else if (negative > threshold) {
    return negative * (-1);
}
// Neither class is confident enough; also satisfies the compiler's definite-return check.
return 0;