1. spark-xgboost Java包
主要需要 xgboost4j-spark-0.90.jar、xgboost4j-0.90.jar,以及 Python 端的调用代码 sparkxgb.zip。
GitHub 上有 XGBoost 的 Java 实现包(链接:xgboost);
但为了省事,我直接使用了知乎文章《xgboost的分布式版本(pyspark)使用测试》中提供的下载链接。
注意:xgboost jar 包的版本号必须与 sparkxgb 中代码所对应的版本一致。
2. xgboost多分类
我使用 pyspark 运行,通过 `pyspark --jars <jar包路径>`(多个 jar 用逗号分隔)
把用到的这两个 jar 包引入。
#!/usr/bin/env python
# -*- coding:utf8 -*-
import os
import sys
import time
import pandas as pd
import numpy as np
from pyspark import SparkConf, SparkContext
import pyspark.sql.types as typ
import pyspark.ml.feature as ft
from pyspark.sql.functions import isnan, isnull,col
import pyspark
from pyspark.sql.session import SparkSession
from pyspark.sql import SQLContext
from pyspark.sql.types import *
from pyspark.ml.feature import StringIndexer,VectorAssembler
from pyspark.ml.linalg import Vectors
from pyspark.ml import Pipeline
from sparkxgb import XGBoostClassifier
import sklearn.datasets as datasets
import numpy as np
import time
def normalize(x, axis=None):
    """Min-max scale *x* into the range [0, 1].

    Args:
        x: array-like of numbers (lists are accepted as well as ndarrays).
        axis: axis along which the min/max are taken. ``None`` (the default,
            matching the original behaviour) uses the global min/max of the
            whole array; ``axis=0`` scales each column (feature) independently.

    Returns:
        A float ndarray with the same shape as *x*. Slices whose values are
        all equal (zero range) map to 0.0 instead of producing NaN from 0/0.
    """
    arr = np.asarray(x, dtype=float)
    # keepdims=True keeps the reduced axes so the stats broadcast back onto arr.
    lo = np.min(arr, axis=axis, keepdims=True)
    span = np.max(arr, axis=axis, keepdims=True) - lo
    # Divide by 1 where the range is zero so constant slices become 0, not NaN.
    safe_span = np.where(span == 0, 1.0, span)
    return (arr - lo) / safe_span
def get_data(n_samples=100000, centers=10, n_features=10, train_ratio=0.8,
             random_state=0, spark_session=None):
    """Build train/test Spark DataFrames from synthetic Gaussian blobs.

    Args:
        n_samples, centers, n_features, random_state: forwarded to
            ``sklearn.datasets.make_blobs``; defaults reproduce the original
            10-class, 10-feature, 100k-row dataset.
        train_ratio: fraction of rows (taken from the front) used for training.
        spark_session: object providing ``createDataFrame``; defaults to the
            module-level ``spark`` created in the ``__main__`` block.

    Returns:
        ``(spark_train, spark_test, y_train, y_test)`` where ``spark_train`` has
        columns ``label`` (int) and ``features`` (dense vector), ``spark_test``
        has only ``features``, ``y_train`` is reshaped to ``(n_train, 1)`` and
        ``y_test`` holds the held-out labels in row order.
    """
    session = spark if spark_session is None else spark_session
    X, y = datasets.make_blobs(n_samples=n_samples, centers=centers,
                               n_features=n_features, random_state=random_state)
    # Min-max scale the features. NOTE: the scaling statistics are computed
    # over all rows, including the ones that end up in the test split.
    X_norm = normalize(X)
    split = int(len(X_norm) * train_ratio)
    X_train, X_test = X_norm[:split], X_norm[split:]
    y_train, y_test = y[:split], y[split:]
    # Pair labels with feature rows directly instead of concatenating them into
    # one float matrix and casting the label back to int afterwards.
    train_rows = [(int(label), Vectors.dense(feats))
                  for label, feats in zip(y_train, X_train)]
    spark_train = session.createDataFrame(train_rows, schema=["label", "features"])
    test_rows = [(Vectors.dense(feats),) for feats in X_test]
    spark_test = session.createDataFrame(test_rows, schema=["features"])
    return spark_train, spark_test, y_train.reshape(-1, 1), y_test
def train_model(trainDF, num_class=10, missing=0.0):
    """Fit a single-stage Pipeline wrapping a multi-class XGBoostClassifier.

    Args:
        trainDF: Spark DataFrame with a ``label`` column and a ``features``
            vector column.
        num_class: passed as ``numClass``; must match the number of distinct
            labels in *trainDF* (10 for the default ``get_data()``).
        missing: passed as ``missing``; kept at 0.0 for compatibility with the
            original script.
            NOTE(review): the features are dense and min-max scaled, so a real
            0.0 value may also be treated as missing by XGBoost; consider
            ``float('nan')`` — confirm against the xgboost4j-spark docs.

    Returns:
        The fitted ``PipelineModel`` (it adds ``prediction`` to transformed rows).
    """
    classifier = XGBoostClassifier(
        featuresCol="features",
        labelCol="label",
        predictionCol="prediction",
        objective='multi:softprob',
        numClass=num_class,
        missing=missing,
    )
    # To persist: model.write().overwrite().save(path) and reload with
    # PipelineModel.load(path). Calling `model.load(...)` on an instance, as an
    # earlier version of this script did, discards the loaded model.
    return Pipeline(stages=[classifier]).fit(trainDF)
def test():
    """Smoke-test the module-level SparkContext ``sc`` with a tiny RDD round trip."""
    rdd = sc.parallelize([1, 2, 3, 4, 5])
    collected = rdd.collect()
    print("done", collected)
def cal_acc(pred, true):
    """Return classification accuracy rounded to 4 decimal places.

    Args:
        pred: iterable of predictions, either plain scalars or one-field rows
            such as the ``Row(prediction=...)`` objects returned by
            ``DataFrame.select("prediction").collect()``.
        true: sequence of ground-truth labels aligned with *pred*.

    Returns:
        Fraction of positions where the prediction equals the label, as a
        float rounded to 4 decimals.

    Raises:
        ValueError: if *true* is empty or *pred* and *true* differ in length.
    """
    preds = list(pred)
    n = len(true)
    if n == 0:
        raise ValueError("cal_acc: 'true' must not be empty")
    if len(preds) != n:
        raise ValueError(
            "cal_acc: length mismatch ({} predictions vs {} labels)".format(len(preds), n))
    # A pyspark Row is a tuple subclass. Unwrap it explicitly instead of relying
    # on numpy broadcasting a 1-tuple against a numpy label; that comparison is
    # always False when `true` is a plain Python list.
    values = (row[0] if isinstance(row, tuple) else row for row in preds)
    correct = sum(1 for value, label in zip(values, true) if value == label)
    return round(float(correct) / n, 4)
if __name__ == "__main__":
    # Register the XGBoost jars before the JVM starts. NOTE: if a Spark context
    # already exists (e.g. inside the `pyspark` shell), getOrCreate() returns it
    # and this setting may not take effect; pass --jars on the command line then.
    conf = SparkConf().set(
        "spark.jars",
        "/home/xgboost4j-0.90.jar,/home/xgboost4j-spark-0.90.jar",
    )
    # These are module-level globals: get_data() reads `spark`, test() reads `sc`.
    # SparkSession replaces the deprecated SQLContext and also offers createDataFrame.
    spark = SparkSession.builder.config(conf=conf).getOrCreate()
    sc = spark.sparkContext
    try:
        trainDf, testDf, y_train, y_test = get_data()
        print('get df')
        model = train_model(trainDf)
        prediction = model.transform(testDf).select("prediction").collect()
        acc = cal_acc(prediction, y_test)
        print("acc:{}".format(acc))
    finally:
        # Release the Spark context even when data prep or training fails.
        spark.stop()
运行结果:acc:0.9992
预测结果:
# Display the transformed test set: next to the input `features`, the model
# output has rawPrediction, probability and prediction columns (sample below).
model.transform(testDf).show()
+--------------------+--------------------+--------------------+----------+
| features| rawPrediction| probability|prediction|
+--------------------+--------------------+--------------------+----------+
|[0.36383649267021...|[0.33353492617607...|[0.06999947130680...| 9.0|
|[0.85080275306445...|[0.33345550298690...|[0.06996602565050...| 2.0|
|[0.54471116142668...|[1.99881935119628...|[0.37008801102638...| 0.0|
|[0.61089833342796...|[0.33345550298690...|[0.06995990127325...| 5.0|
|[0.25437385667790...|[0.33415806293487...|[0.07003305852413...| 6.0|
|[0.47371795998355...|[1.99881935119628...|[0.37008947134017...| 0.0|
|[0.75258857302126...|[0.33345550298690...|[0.07017561793327...| 2.0|
|[0.38430822786126...|[0.33345550298690...|[0.06999430805444...| 9.0|
|[0.84192691973241...|[0.33345550298690...|[0.06999272853136...| 7.0|
|[0.89822104638187...|[0.33345550298690...|[0.06999462842941...| 2.0|
|[0.87335367752325...|[0.33345550298690...|[0.06999401748180...| 2.0|
|[0.34598394310439...|[0.33365276455879...|[0.07000749558210...| 9.0|
|[0.37907532566580...|[0.33345550298690...|[0.06999314576387...| 8.0|
|[0.85996665363900...|[0.33345550298690...|[0.06998810172080...| 7.0|
|[0.52503470825319...|[1.99881935119628...|[0.37008947134017...| 0.0|
|[0.51847376135870...|[0.33345550298690...|[0.06998340785503...| 5.0|
|[0.51366954373353...|[1.98586511611938...|[0.36707320809364...| 0.0|
|[0.38344970186248...|[0.33345550298690...|[0.06998835504055...| 4.0|
|[0.31206934826790...|[0.33353492617607...|[0.06996974349021...| 6.0|
|[0.68235540326326...|[0.33345550298690...|[0.06998881697654...| 1.0|
+--------------------+--------------------+--------------------+----------+
参考: