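This week I planned to build a Chinese text classification model with word2vec + LSTM, but my boss had used libsvm before and asked me to use libsvm instead. After two days of fiddling I basically have it working. I ran into all kinds of problems along the way, so I'm writing them down here.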
First, download the libsvm package from this link:
http://www.csie.ntu.edu.tw/~cjlin/cgi-bin/libsvm.cgi?+http://www.csie.ntu.edu.tw/~cjlin/libsvm+zip
After downloading, unzip it, enter the directory and just run make.
Next, convert the text data into the following format:
2 2017:1.23527900896424 2080:1.3228803416955244 21233:3.475992040593523
2 576:1.0467435856485432 967:1.0968877798239958 3940:1.7482714392181495 4449:1.7535719911308003
2 967:1.0968877798239958 1336:1.3551722790297116 5611:1.8303003497257173 14735:1.7682821161365336
1 7:0.02425295226485008 32:0.009012036411194203 80:0.0057407001135544745 127:0.020374370371014396
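This is the standard libsvm format: each line is a label followed by index:value pairs. I used the ansj tool for word segmentation and tf-idf weights as the values. The feature indices on each line must be in ascending order, otherwise libsvm reports the following error during training: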
Libsvm : Wrong input format at line 1
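To catch this before training, here is a minimal sketch of a checker that reports which lines break the format. It assumes whitespace-separated tokens and indices that start from 1 and are strictly ascending, as described in the libsvm README. In places it is stricter than libsvm's own reader. The class name LibsvmLineCheck is just for illustration.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class LibsvmLineCheck {
    // True if the line looks like "<label> <index>:<value> ..." with a numeric label,
    // numeric values, and integer indices >= 1 in strictly ascending order.
    static boolean isValid(String line) {
        String[] tokens = line.trim().split("\\s+");
        if (tokens[0].isEmpty()) {
            return false; // blank line: no label
        }
        try {
            Double.parseDouble(tokens[0]); // the label must be a number
            int prev = 0;
            for (int i = 1; i < tokens.length; i++) {
                String[] kv = tokens[i].split(":");
                if (kv.length != 2) {
                    return false; // not an index:value pair
                }
                int index = Integer.parseInt(kv[0]);
                Double.parseDouble(kv[1]);
                if (index <= prev) {
                    return false; // out of order, duplicated, or below 1
                }
                prev = index;
            }
        } catch (NumberFormatException e) {
            return false;
        }
        return true;
    }

    // 1-based numbers of the lines that fail the check above.
    static List<Integer> badLines(List<String> lines) {
        List<Integer> bad = new ArrayList<>();
        for (int n = 0; n < lines.size(); n++) {
            if (!isValid(lines.get(n))) {
                bad.add(n + 1);
            }
        }
        return bad;
    }

    public static void main(String[] args) {
        List<String> lines = Arrays.asList(
                "2 576:1.0467435856485432 967:1.0968877798239958",
                "2 967:1.0968877798239958 576:1.0467435856485432");
        System.out.println("bad lines: " + badLines(lines)); // prints "bad lines: [2]"
    }
}
```

Running it over the generated file before svm-train shows every offending line at once. The libsvm error only reports the first one.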
For more details on usage, see this blog post: http://endual.iteye.com/blog/1267442. The key is knowing how to generate the libsvm-format file.
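Below is the code of my tool that converts text into libsvm format. It uses a lot of Java 1.8 features. I'm used to writing Scala, so Java suddenly feels quite verbose; please bear with me: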
package com.meituan.model.libsvm;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
import java.util.StringJoiner;
import java.util.TreeMap;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import org.ansj.splitWord.analysis.ToAnalysis;
import org.apache.commons.lang3.StringUtils;

import com.meituan.model.util.Config;
import com.meituan.nlp.util.TextUtil;
import com.meituan.nlp.util.WordUtil;

// Terms, TFIDF, WordUtil, TextUtil and Config are project-internal helpers not shown in this post.
public class DocumentTransForm {
    private static String inputpath = Config.getString("data.path");
    private static String outputpath = Config.getString("data.libsvm");
    // term -> Terms (feature id and document frequency)
    private static Map<String, Terms> mapTerms = new HashMap<>();
    public static int documentTotal = 0;

    // Shared preprocessing: lowercase, WordUtil.replaceAll, traditional-to-simplified (fan2Jian),
    // synonym replacement, ansj segmentation, then drop stopwords, single characters and tokens
    // starting with a number. Both passes must tokenize identically, otherwise pass 2 can meet
    // a term that pass 1 never put into mapTerms.
    private static Stream<String> tokenize(String text) {
        return ToAnalysis
                .parse(WordUtil.replaceAllSynonyms(TextUtil.fan2Jian(WordUtil.replaceAll(text.toLowerCase()))))
                .getTerms()
                .stream()
                .map(x -> x.getName())
                .filter(x -> !WordUtil.isStopword(x) && x.length() > 1 && !WordUtil.startWithNumeber(x));
    }

    // Pass 1: give every term a feature id (starting at 1) and count its document frequency.
    public static void getTerms(String file) {
        int featurecount = 1;
        // Assumes the corpus is UTF-8 encoded; try-with-resources closes the reader.
        try (BufferedReader br = new BufferedReader(new InputStreamReader(
                new FileInputStream(file), StandardCharsets.UTF_8))) {
            String lines;
            while ((lines = br.readLine()) != null) {
                if (StringUtils.isBlank(lines)) {
                    continue;
                }
                // Collect into a Set so each term is counted at most once per document.
                for (String key : tokenize(lines.split("\t")[0]).collect(Collectors.toSet())) {
                    Terms terms = mapTerms.get(key);
                    if (terms == null) {
                        mapTerms.put(key, new Terms(key, featurecount++));
                    } else {
                        terms.incrFreq();
                    }
                }
                documentTotal++;
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    // Pass 2: write one libsvm line "<label> <id>:<tfidf> ..." per document.
    public static void getLibsvmFile(String input, String output) {
        // Both streams are closed automatically, even if one of them fails to open.
        try (BufferedReader br = new BufferedReader(new InputStreamReader(
                     new FileInputStream(input), StandardCharsets.UTF_8));
             BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(
                     new FileOutputStream(output), StandardCharsets.UTF_8))) {
            String lines;
            // Read to end of file; a blank line no longer stops the loop early.
            while ((lines = br.readLine()) != null) {
                String[] cols = lines.split("\t");
                if (cols.length < 2) {
                    continue; // blank line or missing label column
                }
                // Input label -1 becomes class 2, anything else becomes class 1.
                String label = cols[1].equalsIgnoreCase("-1") ? "2" : "1";
                Map<String, Long> maps = tokenize(cols[0])
                        .collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));
                if (maps.isEmpty()) {
                    continue;
                }
                // Total number of kept tokens in this document (denominator of tf).
                int sum = (int) maps.values().stream().mapToLong(Long::longValue).sum();
                // TreeMap keeps the feature ids in ascending order, which libsvm requires.
                Map<Integer, Double> treeMap = new TreeMap<>();
                for (Map.Entry<String, Long> entry : maps.entrySet()) {
                    // This key always exists: pass 1 read the same file with the same preprocessing.
                    Terms terms = mapTerms.get(entry.getKey());
                    double tf = TFIDF.tf(entry.getValue(), sum);
                    double idf = TFIDF.idf(documentTotal, terms.getFreq());
                    treeMap.put(terms.getId(), TFIDF.tfidf(tf, idf));
                }
                StringJoiner sj = new StringJoiner(" ");
                sj.add(label);
                treeMap.forEach((id, value) -> sj.add(id + ":" + value));
                bw.write(sj.toString());
                bw.newLine();
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    public static void main(String[] args) {
        getTerms(inputpath);
        System.out.println("documentTotal is: " + documentTotal);
        getLibsvmFile(inputpath, outputpath);
    }
}
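The TFIDF and Terms classes are not shown above, so here is a simplified, self-contained sketch of the same two-pass idea. It is not the code above. Whitespace splitting stands in for ansj, and I assume the textbook definitions tf = count / total tokens and idf = ln(N / df), which may differ from what TFIDF actually computes. All names here are hypothetical.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.StringJoiner;
import java.util.TreeMap;

public class TwoPassTfIdfSketch {
    // Hypothetical stand-in for the ansj-based preprocessing: split on whitespace.
    static List<String> tokenize(String text) {
        List<String> out = new ArrayList<>();
        for (String t : text.trim().split("\\s+")) {
            if (!t.isEmpty()) {
                out.add(t);
            }
        }
        return out;
    }

    static List<String> toLibsvm(List<String> labels, List<String> docs) {
        // Pass 1: feature id (from 1, in order of first appearance) and document frequency.
        Map<String, Integer> ids = new HashMap<>();
        Map<String, Integer> df = new HashMap<>();
        for (String doc : docs) {
            Set<String> seen = new HashSet<>();
            for (String t : tokenize(doc)) {
                ids.putIfAbsent(t, ids.size() + 1);
                if (seen.add(t)) {
                    df.merge(t, 1, Integer::sum); // count each term once per document
                }
            }
        }
        // Pass 2: assumed tf = count / total, idf = ln(N / df); TreeMap sorts the indices.
        int n = docs.size();
        List<String> lines = new ArrayList<>();
        for (int d = 0; d < n; d++) {
            List<String> tokens = tokenize(docs.get(d));
            if (tokens.isEmpty()) {
                continue;
            }
            Map<String, Integer> counts = new HashMap<>();
            for (String t : tokens) {
                counts.merge(t, 1, Integer::sum);
            }
            TreeMap<Integer, Double> features = new TreeMap<>();
            for (Map.Entry<String, Integer> e : counts.entrySet()) {
                double tf = (double) e.getValue() / tokens.size();
                double idf = Math.log((double) n / df.get(e.getKey()));
                features.put(ids.get(e.getKey()), tf * idf);
            }
            StringJoiner sj = new StringJoiner(" ");
            sj.add(labels.get(d));
            features.forEach((id, v) -> sj.add(id + ":" + v));
            lines.add(sj.toString());
        }
        return lines;
    }

    public static void main(String[] args) {
        List<String> out = toLibsvm(Arrays.asList("1", "2"), Arrays.asList("a b a", "b c"));
        out.forEach(System.out::println);
    }
}
```

In this toy input, "b" occurs in every document, so its idf is ln(2/2) = 0 and it gets weight 0.0. That is one reason the per-feature scaling step below is useful.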
Then comes training. First scale the data; -l 0 -u 1 maps every feature into [0, 1]:
./svm-scale -l 0 -u 1 /Users/shuubiasahi/Documents/workspace/spark-model/file/libsvem.txt >/Users/shuubiasahi/Documents/workspace/spark-model/file/libsvem_scale.txt
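svm-scale works feature by feature. As a rough illustration (not svm-scale's actual source), the sketch below applies the same min-max mapping to sparse vectors:

- A feature missing from a line counts as 0 when finding that feature's min and max.
- Features whose min equals max are dropped, which as far as I know mirrors svm-scale.
- It assumes lower = 0 and non-negative values such as tf-idf, so an absent feature stays 0 after scaling.

MinMaxScaleSketch is a hypothetical name.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

public class MinMaxScaleSketch {
    // Scale each feature of sparse vectors (feature id -> value) into [lower, upper].
    // Assumes lower = 0 and non-negative inputs, so features absent from a doc stay 0.
    static List<TreeMap<Integer, Double>> scale(List<Map<Integer, Double>> docs, double lower, double upper) {
        Map<Integer, double[]> range = new HashMap<>(); // feature id -> {min, max}
        Map<Integer, Integer> seen = new HashMap<>();   // feature id -> number of docs containing it
        for (Map<Integer, Double> doc : docs) {
            for (Map.Entry<Integer, Double> e : doc.entrySet()) {
                double v = e.getValue();
                range.merge(e.getKey(), new double[] {v, v},
                        (a, b) -> new double[] {Math.min(a[0], b[0]), Math.max(a[1], b[1])});
                seen.merge(e.getKey(), 1, Integer::sum);
            }
        }
        // A feature missing from some document counts as 0 there, so 0 joins its range.
        for (Map.Entry<Integer, double[]> e : range.entrySet()) {
            if (seen.get(e.getKey()) < docs.size()) {
                double[] r = e.getValue();
                r[0] = Math.min(r[0], 0.0);
                r[1] = Math.max(r[1], 0.0);
            }
        }
        List<TreeMap<Integer, Double>> out = new ArrayList<>();
        for (Map<Integer, Double> doc : docs) {
            TreeMap<Integer, Double> scaled = new TreeMap<>(); // keeps indices ascending
            for (Map.Entry<Integer, Double> e : doc.entrySet()) {
                double[] r = range.get(e.getKey());
                if (r[1] == r[0]) {
                    continue; // constant feature carries no information: dropped
                }
                double v = lower + (upper - lower) * (e.getValue() - r[0]) / (r[1] - r[0]);
                if (v != 0.0) {
                    scaled.put(e.getKey(), v); // zeros are left out to keep the line sparse
                }
            }
            out.add(scaled);
        }
        return out;
    }

    public static void main(String[] args) {
        Map<Integer, Double> d1 = new HashMap<>();
        d1.put(1, 2.0);
        d1.put(3, 4.0);
        Map<Integer, Double> d2 = new HashMap<>();
        d2.put(1, 4.0);
        System.out.println(scale(Arrays.asList(d1, d2), 0.0, 1.0)); // prints [{3=1.0}, {1=1.0}]
    }
}
```

Feature 1 appears in both documents, so its range is [2, 4] and the value 2.0 maps to 0 (omitted). Feature 3 is missing from the second document, so its range becomes [0, 4].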
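Now train. libsvm provides quite a few parameters, which ./svm-train lists when you run it without arguments. I used -t 0 (linear kernel) and -h 0 (shrinking heuristics turned off):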
./svm-train -h 0 -t 0 /Users/shuubiasahi/Documents/workspace/spark-model/file/libsvem_scale.txt /Users/shuubiasahi/Documents/workspace/spark-model/file/model.txt
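With -t 0, the trained model is linear. The model file stores a rho value and, for each support vector, a coefficient (y_i * alpha_i) followed by its features. For a two-class problem the decision value is f(x) = sum_i coef_i * <sv_i, x> - rho. In libsvm's convention, f(x) > 0 votes for the first label listed in the model file.

Because the kernel is a plain dot product, the support vectors can be collapsed into one weight vector w = sum_i coef_i * sv_i. The sketch below shows that, with made-up numbers. LinearSvmSketch and its methods are hypothetical and do not read a real model file.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class LinearSvmSketch {
    // w = sum_i coef_i * sv_i, where coef_i plays the role of the coefficient stored
    // in front of each support vector in a libsvm model file.
    static Map<Integer, Double> weights(List<Map<Integer, Double>> svs, double[] coef) {
        Map<Integer, Double> w = new HashMap<>();
        for (int i = 0; i < svs.size(); i++) {
            for (Map.Entry<Integer, Double> e : svs.get(i).entrySet()) {
                w.merge(e.getKey(), coef[i] * e.getValue(), Double::sum);
            }
        }
        return w;
    }

    // Decision value f(x) = w . x - rho, summing only over the non-zero features of x.
    static double decision(Map<Integer, Double> w, double rho, Map<Integer, Double> x) {
        double s = 0.0;
        for (Map.Entry<Integer, Double> e : x.entrySet()) {
            s += w.getOrDefault(e.getKey(), 0.0) * e.getValue();
        }
        return s - rho;
    }

    public static void main(String[] args) {
        // Toy numbers for illustration only, not from a trained model.
        Map<Integer, Double> sv1 = new HashMap<>();
        sv1.put(1, 1.0);
        sv1.put(3, 0.5);
        Map<Integer, Double> sv2 = new HashMap<>();
        sv2.put(2, 1.0);
        List<Map<Integer, Double>> svs = new ArrayList<>();
        svs.add(sv1);
        svs.add(sv2);
        Map<Integer, Double> w = weights(svs, new double[] {0.5, -0.5});
        Map<Integer, Double> x = new HashMap<>();
        x.put(1, 1.0);
        x.put(2, 0.2);
        // f > 0 votes for the first label in the model file, otherwise the second.
        System.out.println("w = " + w + ", f(x) = " + decision(w, 0.1, x));
    }
}
```

Collapsing to w also makes the linear model easy to inspect. The largest positive and negative weights point at the terms (feature ids from the conversion step) that push a document toward each class.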
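Personally I think the theory behind SVMs is fairly simple. Li Hang's book Statistical Learning Methods (统计学习方法) explains it clearly enough to follow on a first read.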