xgboost参数说明：在网上找了很多关于xgboost的文章，基本上90%都是用Python来说明的，用Java的很少。
xgboost的各个参数在这篇文章里说明得很详细：http://blog.csdn.net/zc02051126/article/details/46711047 。
在Java中使用时，只需要这样设置参数：
// XGBoost training parameters for xgboost4j, collected in a String -> Object map.
Map<String, Object> params = new HashMap<String, Object>();
// eta: step-size shrinkage used during updates to prevent overfitting; after each boosting
// step the algorithm obtains the weights of the new features, which eta then scales.
// NOTE(review): a value of 1.0 applies no shrinkage at all — confirm this is intended.
params.put("eta", 1.0);
// max_depth: maximum depth of each tree.
params.put("max_depth", 15);
// silent: 1 suppresses the per-iteration training messages, 0 prints them.
// NOTE(review): newer XGBoost releases deprecate `silent` in favour of `verbosity` —
// check against the xgboost4j version in use.
params.put("silent", 1);
// lambda: L2 regularization option (used here together with the logistic objective below).
params.put("lambda", 2);
params.put("min_child_weight", 6);
// params.put("nthread", 6); // if not set, the default is the maximum available number of threads
// objective: the learning objective to optimise (binary logistic here).
params.put("objective", "binary:logistic");
关于xgboost训练数据的格式，官方DMatrix类提供的构造函数主要有三种：
第一种是传入libsvm格式文件所在的磁盘路径（官网提供的示例数据也是这种），把libsvm格式的数据文件转化为DMatrix对象。
看看这个类的源码就会发现，它也是调用C++底层代码，核心实现还是C++，无论是Python、Java还是Scala，都只是一层外壳。
第二种采用的是LabeledPoint格式，这也是libsvm格式的一个变种，用起来不太方便，会把数据缓存到一个指定的缓存目录里去。
第三种是直接传入稀疏矩阵数组并指定DMatrix.SparseType（CSR或CSC），这种我比较喜欢，最后同样转化为DMatrix。
模型预测的输入也都是DMatrix类型的参数。
说了这么多，模型训练和保存就不贴代码说明了，下面用代码说明一下模型预测：在GitHub上提供的一个例子的基础上，我又加了两个方法，
作用是把一行文本转化为DMatrix类型，以供模型预测：
package com.meituan.model.xgboost;
import java.io.*;
import java.math.BigDecimal;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.TreeMap;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoostError;

import org.ansj.splitWord.analysis.ToAnalysis;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.lang3.ArrayUtils;

import com.meituan.model.libsvm.TFIDF;
import com.meituan.model.libsvm.Terms;
import com.meituan.nlp.util.TextUtil;
import com.meituan.nlp.util.WordUtil;
/**
 * Loaders that turn training files or raw text into the dense / CSR arrays used to build
 * xgboost4j {@link DMatrix} instances, plus a helper that scores a single piece of text.
 */
public class DataLoader {

    /** libsvm tokens are separated by runs of whitespace; compiled once. */
    private static final Pattern WHITESPACE = Pattern.compile("\\s+");

    /**
     * Special entry of the terms map whose frequency is passed to {@code TFIDF.idf} as its first
     * argument (presumably the total document count — confirm against the vocabulary builder).
     */
    private static final String DOCUMENT_TOTAL_KEY = "documentTotal";

    /** Dense feature matrix with one label per row, as produced by {@link #loadCSVFile(String)}. */
    public static class DenseData {
        /** One label per row. */
        public float[] labels;
        /** Feature values in row-major order (each row's features appended after the previous row's). */
        public float[] data;
        /** Number of rows read. */
        public int nrow;
        /** Feature columns per row, taken from the first row; -1 if no row was read. */
        public int ncol;
    }

    /**
     * Sparse matrix in CSR layout plus labels. Row {@code i} consists of the entries
     * {@code data[rowHeaders[i]] .. data[rowHeaders[i + 1] - 1]} with matching column ids in
     * {@code colIndex}, so {@code rowHeaders} has one more element than there are rows.
     */
    public static class CSRSparseData {
        public float[] labels;
        public float[] data;
        public long[] rowHeaders;
        public int[] colIndex;
    }

    /**
     * Loads a comma-separated file in which the last column of each line is the label and all
     * preceding columns are feature values. Blank lines are skipped.
     *
     * @param filePath path of a UTF-8 encoded CSV file without a header row
     * @return the parsed rows; the column count is taken from the first row and not re-checked
     * @throws IOException if the file cannot be opened or read
     * @throws NumberFormatException if a cell is not a valid float
     */
    public static DenseData loadCSVFile(String filePath) throws IOException {
        DenseData denseData = new DenseData();
        denseData.nrow = 0;
        denseData.ncol = -1;
        List<Float> tlabels = new ArrayList<>();
        List<Float> tdata = new ArrayList<>();
        // try-with-resources: the reader (and underlying stream) is closed even if parsing throws.
        try (BufferedReader reader = newUtf8Reader(filePath)) {
            String line;
            while ((line = reader.readLine()) != null) {
                String trimmed = line.trim();
                // "".split(",") returns [""] (length 1), so blank lines must be caught explicitly;
                // otherwise Float.valueOf("") would throw.
                if (trimmed.isEmpty()) {
                    continue;
                }
                String[] items = trimmed.split(",");
                if (items.length == 0) {
                    continue; // e.g. ",,," — split drops trailing empty strings
                }
                denseData.nrow++;
                if (denseData.ncol == -1) {
                    denseData.ncol = items.length - 1;
                }
                tlabels.add(Float.valueOf(items[items.length - 1]));
                for (int i = 0; i < items.length - 1; i++) {
                    tdata.add(Float.valueOf(items[i]));
                }
            }
        }
        denseData.labels = ArrayUtils.toPrimitive(tlabels.toArray(new Float[0]));
        denseData.data = ArrayUtils.toPrimitive(tdata.toArray(new Float[0]));
        return denseData;
    }

    /**
     * Loads a libsvm-format file ({@code label index:value index:value ...} per line) into CSR
     * arrays. Tokens may be separated by any run of whitespace; blank lines are skipped.
     *
     * @param filePath path of a UTF-8 encoded libsvm file
     * @return labels plus the CSR row headers, column indices and values
     * @throws IOException if the file cannot be read or a feature token is not {@code index:value}
     * @throws NumberFormatException if a label, index or value cannot be parsed
     */
    public static CSRSparseData loadSVMFile(String filePath) throws IOException {
        CSRSparseData spData = new CSRSparseData();
        List<Float> tlabels = new ArrayList<>();
        List<Float> tdata = new ArrayList<>();
        List<Long> theaders = new ArrayList<>();
        List<Integer> tindex = new ArrayList<>();
        long rowheader = 0;
        theaders.add(rowheader); // CSR row headers start at offset 0
        try (BufferedReader reader = newUtf8Reader(filePath)) {
            String line;
            while ((line = reader.readLine()) != null) {
                String trimmed = line.trim();
                if (trimmed.isEmpty()) {
                    continue;
                }
                // Splitting on whitespace runs avoids empty tokens from doubled spaces or tabs.
                String[] items = WHITESPACE.split(trimmed);
                rowheader += items.length - 1;
                theaders.add(rowheader);
                tlabels.add(Float.valueOf(items[0]));
                for (int i = 1; i < items.length; i++) {
                    String[] tup = items[i].split(":");
                    // Explicit check: a bare `assert` is a no-op unless the JVM runs with -ea.
                    if (tup.length != 2) {
                        throw new IOException("Malformed libsvm feature '" + items[i]
                                + "' in line: " + line);
                    }
                    tdata.add(Float.valueOf(tup[1]));
                    tindex.add(Integer.valueOf(tup[0]));
                }
            }
        }
        spData.labels = ArrayUtils.toPrimitive(tlabels.toArray(new Float[0]));
        spData.data = ArrayUtils.toPrimitive(tdata.toArray(new Float[0]));
        spData.colIndex = ArrayUtils.toPrimitive(tindex.toArray(new Integer[0]));
        spData.rowHeaders = ArrayUtils.toPrimitive(theaders.toArray(new Long[0]));
        return spData;
    }

    /**
     * Converts one piece of text into a single-row CSR matrix of TF-IDF weights.
     *
     * <p>The text is lower-cased, passed through the project's WordUtil/TextUtil normalisation
     * helpers, segmented with ansj {@code ToAnalysis}, and stop words are dropped. Every remaining
     * term found in {@code termsmap} becomes one feature: its column is the term id and its value is
     * {@code TFIDF.tfidf(tf, idf)}. Terms missing from {@code termsmap} are ignored.
     *
     * @param content  raw text to vectorise
     * @param termsmap vocabulary (term -> {@link Terms}); must also hold a "documentTotal" entry
     *                 whenever at least one term of {@code content} is in the vocabulary
     * @return a one-row matrix whose {@code labels} array is empty, or {@code null} if
     *         {@code content} is blank or none of its terms is in the vocabulary
     * @throws NullPointerException if a vocabulary term is found but "documentTotal" is missing
     */
    public static CSRSparseData getSparseData(String content, Map<String, Terms> termsmap) {
        if (StringUtils.isBlank(content)) {
            return null;
        }
        // Locale.ROOT keeps lower-casing independent of the JVM's default locale.
        Map<String, Long> counts = ToAnalysis
                .parse(WordUtil.replaceAllSynonyms(TextUtil.fan2Jian(WordUtil
                        .replaceAll(content.toLowerCase(Locale.ROOT)))))
                .getTerms()
                .stream()
                .map(x -> x.getName())
                .filter(x -> !WordUtil.isStopword(x))
                .collect(Collectors.groupingBy(p -> p, Collectors.counting()));
        if (counts.isEmpty()) {
            return null;
        }
        // Total number of kept tokens, used as the TF denominator.
        int sum = (int) counts.values().stream().mapToLong(Long::longValue).sum();

        // Loop-invariant lookup, hoisted out of the loop below.
        Terms documentTotal = termsmap.get(DOCUMENT_TOTAL_KEY);
        // TreeMap keeps the row's column ids in ascending order.
        Map<Integer, Double> features = new TreeMap<>();
        for (Entry<String, Long> entry : counts.entrySet()) {
            Terms keyword = termsmap.get(entry.getKey());
            if (keyword == null) {
                continue; // out-of-vocabulary term
            }
            Objects.requireNonNull(documentTotal, "termsmap has no \"documentTotal\" entry");
            int id = keyword.getId();
            double tf = TFIDF.tf(entry.getValue(), sum);
            double idf = TFIDF.idf(documentTotal.getFreq(), keyword.getFreq());
            double tfidf = TFIDF.tfidf(tf, idf);
            features.put(id, tfidf);
        }
        if (features.isEmpty()) {
            return null;
        }

        int nnz = features.size();
        CSRSparseData spData = new CSRSparseData();
        spData.labels = new float[0]; // no labels: this row is only used for prediction
        spData.rowHeaders = new long[] {0L, nnz};
        spData.colIndex = new int[nnz];
        spData.data = new float[nnz];
        int k = 0;
        for (Entry<Integer, Double> feature : features.entrySet()) {
            spData.colIndex[k] = feature.getKey();
            // Round through the shortest decimal form of the double, as the original code did.
            spData.data[k] = new BigDecimal(Double.toString(feature.getValue())).floatValue();
            k++;
        }
        return spData;
    }

    /**
     * Scores one piece of text with {@code booster}.
     *
     * @param booster  trained model
     * @param content  raw text to classify
     * @param termsmap vocabulary, see {@link #getSparseData(String, Map)}
     * @return the first output of the first (only) row of {@code booster.predict}, or {@code 0.0}
     *         — not a model output — when the text is blank or has no in-vocabulary terms
     * @throws XGBoostError if building the DMatrix or predicting fails
     */
    public static double getClassify(Booster booster, String content, Map<String, Terms> termsmap)
            throws XGBoostError {
        CSRSparseData spData = getSparseData(content, termsmap);
        if (spData == null) {
            return 0.0;
        }
        // shapeParam 0: the column count is not passed explicitly.
        DMatrix data = new DMatrix(spData.rowHeaders, spData.colIndex, spData.data,
                DMatrix.SparseType.CSR, 0);
        try {
            return booster.predict(data)[0][0];
        } finally {
            // DMatrix wraps a native handle; release it now rather than waiting for GC.
            data.dispose();
        }
    }

    /** Opens {@code filePath} as a buffered UTF-8 reader; the caller must close it. */
    private static BufferedReader newUtf8Reader(String filePath) throws IOException {
        return new BufferedReader(new InputStreamReader(new FileInputStream(filePath),
                StandardCharsets.UTF_8));
    }
}