xgboost参数说明,模型训练,模型预测java接口相关说明


xgboost参数说明在网上找了很多关于xgboost的文章,基本上90%都是以python在说明的,java的很少,

xgboost参数说明http://blog.csdn.net/zc02051126/article/details/46711047在这篇文章里面说明的很详细,

在java中使用的话,只要:

	Map<String, Object> params = new HashMap<String, Object>();
		params.put("eta", 1.0); //为了防止过拟合,更新过程中用到的收缩步长。在每次提升计算之后,算法会直接获得新特征的权重
		params.put("max_depth", 15);//叔最大深度
		params.put("silent", 1); //为1的时候不会打印模型迭代的信息,为0可以看到打印的信息
		params.put("lambda", 2);//用于逻辑回归的时候L2正则选项
		params.put("min_child_weight", 6);
//		 params.put("nthread", 6);  //不使用的话系统会默认得到最大的线程数目
		params.put("objective", "binary:logistic");//目标函数值


关于xgboost数据训练格式,官网DMatrix提供的构造函数主要有三种:

第一种是采用的是l提供ibsvm格式文件所在磁盘路径,官网提供的数据也是这个例子,然后把libsvm格式数据文件转化为DMatrix类,

去看看这个类的源码,也是调用c++底层代码,核心代码还是c++,无论是python、java、scala都值一个外壳。


第二种采用的是LabeledPoint格式,这也是变种libsvm格式文件,用这个不大方便,会把数据缓存到一个目标里面去。


第三种采用的是DMatrix.SparseType,这个我还是比较喜欢,最后转化Dmatrix。


其中预测输入都是用的DMatrix类型参数。



说了这么多,关于模型训练、保存不上代码说明,看看模型预测,用写代码说明下,在git上提供的一个例子再加了两个方法,

这个方法作用是把一行文本转化为DMatrix类型,以提供模型预测:

package com.meituan.model.xgboost;

import java.io.*;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.Map.Entry;
import java.util.stream.Collectors;

import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoostError;

import org.ansj.splitWord.analysis.ToAnalysis;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.lang3.ArrayUtils;

import com.meituan.model.libsvm.TFIDF;
import com.meituan.model.libsvm.Terms;
import com.meituan.nlp.util.TextUtil;
import com.meituan.nlp.util.WordUtil;

public class DataLoader {
	public static class DenseData {
		public float[] labels;
		public float[] data;
		public int nrow;
		public int ncol;
	}

	public static class CSRSparseData {
		public float[] labels;
		public float[] data;
		public long[] rowHeaders;
		public int[] colIndex;
	}

	public static DenseData loadCSVFile(String filePath) throws IOException {
		DenseData denseData = new DenseData();

		File f = new File(filePath);
		FileInputStream in = new FileInputStream(f);
		BufferedReader reader = new BufferedReader(new InputStreamReader(in,
				"UTF-8"));

		denseData.nrow = 0;
		denseData.ncol = -1;
		String line;
		List<Float> tlabels = new ArrayList<>();
		List<Float> tdata = new ArrayList<>();

		while ((line = reader.readLine()) != null) {
			String[] items = line.trim().split(",");
			if (items.length == 0) {
				continue;
			}
			denseData.nrow++;
			if (denseData.ncol == -1) {
				denseData.ncol = items.length - 1;
			}

			tlabels.add(Float.valueOf(items[items.length - 1]));
			for (int i = 0; i < items.length - 1; i++) {
				tdata.add(Float.valueOf(items[i]));
			}
		}

		reader.close();
		in.close();

		denseData.labels = ArrayUtils.toPrimitive(tlabels
				.toArray(new Float[tlabels.size()]));
		denseData.data = ArrayUtils.toPrimitive(tdata.toArray(new Float[tdata
				.size()]));

		return denseData;
	}

	public static CSRSparseData loadSVMFile(String filePath) throws IOException {
		CSRSparseData spData = new CSRSparseData();

		List<Float> tlabels = new ArrayList<>();
		List<Float> tdata = new ArrayList<>();
		List<Long> theaders = new ArrayList<>();
		List<Integer> tindex = new ArrayList<>();

		File f = new File(filePath);
		FileInputStream in = new FileInputStream(f);
		BufferedReader reader = new BufferedReader(new InputStreamReader(in,
				"UTF-8"));

		String line;
		long rowheader = 0;
		theaders.add(rowheader);
		while ((line = reader.readLine()) != null) {
			String[] items = line.trim().split(" ");
			if (items.length == 0) {
				continue;
			}

			rowheader += items.length - 1;
			theaders.add(rowheader);
			tlabels.add(Float.valueOf(items[0]));

			for (int i = 1; i < items.length; i++) {
				String[] tup = items[i].split(":");
				assert tup.length == 2;

				tdata.add(Float.valueOf(tup[1]));
				tindex.add(Integer.valueOf(tup[0]));
			}
		}

		spData.labels = ArrayUtils.toPrimitive(tlabels
				.toArray(new Float[tlabels.size()]));
		spData.data = ArrayUtils.toPrimitive(tdata.toArray(new Float[tdata
				.size()]));
		spData.colIndex = ArrayUtils.toPrimitive(tindex
				.toArray(new Integer[tindex.size()]));
		spData.rowHeaders = ArrayUtils.toPrimitive(theaders
				.toArray(new Long[theaders.size()]));

		return spData;
	}

	public static CSRSparseData getSparseData(String content,Map<String, Terms> termsmap){
		if (StringUtils.isBlank(content)) {
			return null;
		}
		Map<String, Long> maps = ToAnalysis
				.parse(WordUtil.replaceAllSynonyms(TextUtil.fan2Jian(WordUtil
						.replaceAll(content.toLowerCase()))))
				.getTerms()
				.stream()
				.map(x -> x.getName())
				.filter(x -> !WordUtil.isStopword(x)
						)
				.collect(Collectors.groupingBy(p -> p, Collectors.counting()));
		if (maps == null || maps.size() == 0) {
			return null;
		}
		int sum = maps.values().stream()
				.reduce((result, element) -> result = result + element).get()
				.intValue();
		Map<Integer, Double> treemap = new TreeMap<Integer, Double>();

		for (Entry<String, Long> map : maps.entrySet()) {
			String key = map.getKey();
			Terms keyword = termsmap.get(key);
			double tf = TFIDF.tf(map.getValue(), sum);
			if (keyword == null) {
				continue;
			}
			int id = keyword.getId();
			double idf = 0;
			idf = TFIDF.idf(termsmap.get("documentTotal").getFreq(),
					keyword.getFreq());
			double tfidf = TFIDF.tfidf(tf, idf);
			treemap.put(id, tfidf);
		}

		if (treemap.size() == 0) {
			return null;
		}

		CSRSparseData spData = new CSRSparseData();
		List<Float> tlabels = new ArrayList<>();
		List<Float> tdata = new ArrayList<>();
		List<Long> theaders = new ArrayList<>();
		List<Integer> tindex = new ArrayList<>();
		theaders.add(0l);
		theaders.add((long) treemap.size());

		for (Entry<Integer, Double> map : treemap.entrySet()) {

			BigDecimal b = new BigDecimal(Double.toString(map.getValue()));
			tdata.add(b.floatValue());
			tindex.add(Integer.valueOf(map.getKey()));
		}

		spData.labels = ArrayUtils.toPrimitive(tlabels
				.toArray(new Float[tlabels.size()]));
		spData.data = ArrayUtils.toPrimitive(tdata.toArray(new Float[tdata
				.size()]));
		spData.colIndex = ArrayUtils.toPrimitive(tindex
				.toArray(new Integer[tindex.size()]));
		spData.rowHeaders = ArrayUtils.toPrimitive(theaders
				.toArray(new Long[theaders.size()]));

		return spData;
	}
	
	public  static double  getClassify(Booster booster,String content,Map<String, Terms> termsmap) throws XGBoostError{
		CSRSparseData spData=getSparseData(content, termsmap);
		if(spData==null){
			return 0.0;
		}
		DMatrix  data = new DMatrix(spData.rowHeaders, spData.colIndex,
				spData.data, DMatrix.SparseType.CSR, 0);
		return booster.predict(data)[0][0];
	}
	

}




发布了233 篇原创文章 · 获赞 151 · 访问量 80万+
展开阅读全文

没有更多推荐了,返回首页

©️2019 CSDN 皮肤主题: 创作都市 设计师: CSDN官方博客

分享到微信朋友圈

×

扫一扫,手机浏览