梯度提升树(Gradient-Boosted Trees,GBT)是一种将多棵决策树组合起来的流行的分类与回归方法
相对于随机森林(Random Forest)来说,GBT 在实际应用中往往效果更好
直接上代码
package mllib
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.GBTClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature._
import org.apache.spark.sql.SparkSession
/**
 * One parsed record of the user-profile training file consumed by
 * `GBT_Profile`.
 *
 * `String` members hold categorical values, `Double` members hold numeric
 * values. `text` is the "|"-joined concatenation of the categorical members
 * that the pipeline tokenises, and `flag` is the column the pipeline indexes
 * as the classification label.
 *
 * Originally created by dongdong on 17/7/10.
 */
case class Fearture_One(
  // record identifier
  cid: String,

  // demographic / registration attributes
  population_gender: String,
  population_age: Double,
  population_registered_gps_city: String,
  population_education_nature: String,
  population_university_level: String,
  sociality_channel_type: String,
  action_registered_channel: String,
  action_this_month_once_week_average_login_count: Double,

  // city attributes and per-rank cell-city counts
  population_censu_city: String,
  population_gps_city: String,
  population_own_cell_city: String,
  population_rank1_cell_city: String,
  population_rank1_cell_cnt: Double,
  population_rank2_cell_city: String,
  population_rank2_cell_cnt: Double,
  population_rank3_cell_city: String,
  population_rank3_cell_cnt: Double,

  // pairwise *_flag indicators (numeric)
  population_gps_censu_flag: Double,
  population_own_censu_flag: Double,
  population_gps_own_flag: Double,
  population_own_txl_flag: Double,
  population_gps_txl_flag: Double,
  population_censu_txl_flag: Double,

  // activity counters
  population_cnt_7day_province: Double,
  population_cnt_7day_city: Double,
  population_cnt_login: Double,

  // application-address attributes
  population_before_apply_city: String,
  population_after_apply_city: String,
  population_before_in_apply_address: Double,
  population_before_after_apply_address: Double,
  population_in_after_apply_address: Double,
  population_re_address_steady: String,
  population_apply_address_steady: String,

  // numeric score columns
  population_score_fake_gps: Double,
  population_score_fake_contacts: Double,

  // derived text feature and label
  text: String,
  flag: String
)
/**
 * Trains and evaluates a gradient-boosted-tree (GBT) classifier on the
 * user-profile export with a Spark ML pipeline.
 *
 * Pipeline stages, in order: label indexing, "|"-tokenisation of the
 * categorical `text` column, Word2Vec embedding of the tokens, assembly of the
 * embedding with the numeric columns, GBT classification, and conversion of
 * the numeric prediction back to the original label string.
 */
object GBT_Profile {

  /** Field delimiter of the input file: the Ctrl-A control character (code point 1). */
  private val FieldSeparator = "\u0001"

  /** Placeholder found in the input for missing values (presumably the export's NULL marker). */
  private val MissingMarker = "\\N"

  /**
   * Parses one delimited input line into a [[Fearture_One]].
   *
   * Categorical fields have the missing-value marker replaced by "N"; numeric
   * fields have it replaced by "0" before being converted to `Double`.
   *
   * @param line one raw line of the input file
   * @return the parsed record
   * @throws ArrayIndexOutOfBoundsException if the line yields fewer than 142 fields
   *         (note that `String.split` drops trailing empty fields)
   * @throws NumberFormatException if a numeric field cannot be parsed as a `Double`
   */
  private def parseLine(line: String): Fearture_One = {
    val arr = line.split(FieldSeparator)

    // Categorical column: missing marker becomes the literal token "N".
    def str(i: Int): String = arr(i).replace(MissingMarker, "N")
    // Numeric column: missing marker becomes 0.
    def num(i: Int): Double = arr(i).replace(MissingMarker, "0").toDouble

    // Columns 65-66, 74-77, 79-83, 85-86, 88-90 and 94-95 (the jz / ip / jxl /
    // anzhuo variants) were commented out by the original author and remain
    // excluded from the feature set.
    val population_gender = str(3)
    val population_registered_gps_city = str(7)
    val population_education_nature = str(10)
    val population_university_level = str(11)
    val sociality_channel_type = str(13)
    val action_registered_channel = str(44)
    val population_censu_city = str(63)
    val population_gps_city = str(64)
    val population_own_cell_city = str(67)
    val population_rank1_cell_city = str(68)
    val population_rank2_cell_city = str(70)
    val population_rank3_cell_city = str(72)
    val population_before_apply_city = str(107)
    val population_after_apply_city = str(108)
    val population_re_address_steady = str(116)
    val population_apply_address_steady = str(117)

    // All categorical values joined with "|" so the tokenizer can split them
    // back into "words" for Word2Vec.
    val text = Seq(
      population_gender,
      population_registered_gps_city,
      population_education_nature,
      population_university_level,
      sociality_channel_type,
      action_registered_channel,
      population_censu_city,
      population_gps_city,
      population_own_cell_city,
      population_rank1_cell_city,
      population_rank2_cell_city,
      population_rank3_cell_city,
      population_before_apply_city,
      population_after_apply_city,
      population_re_address_steady,
      population_apply_address_steady
    ).mkString("|")

    Fearture_One(
      cid = arr(0),
      population_gender = population_gender,
      population_age = num(4),
      population_registered_gps_city = population_registered_gps_city,
      population_education_nature = population_education_nature,
      population_university_level = population_university_level,
      sociality_channel_type = sociality_channel_type,
      action_registered_channel = action_registered_channel,
      action_this_month_once_week_average_login_count = num(54),
      population_censu_city = population_censu_city,
      population_gps_city = population_gps_city,
      population_own_cell_city = population_own_cell_city,
      population_rank1_cell_city = population_rank1_cell_city,
      population_rank1_cell_cnt = num(69),
      population_rank2_cell_city = population_rank2_cell_city,
      population_rank2_cell_cnt = num(71),
      population_rank3_cell_city = population_rank3_cell_city,
      population_rank3_cell_cnt = num(73),
      population_gps_censu_flag = num(78),
      population_own_censu_flag = num(84),
      population_gps_own_flag = num(87),
      population_own_txl_flag = num(91),
      population_gps_txl_flag = num(92),
      population_censu_txl_flag = num(93),
      population_cnt_7day_province = num(96),
      population_cnt_7day_city = num(97),
      population_cnt_login = num(102),
      population_before_apply_city = population_before_apply_city,
      population_after_apply_city = population_after_apply_city,
      population_before_in_apply_address = num(111),
      population_before_after_apply_address = num(112),
      population_in_after_apply_address = num(113),
      population_re_address_steady = population_re_address_steady,
      population_apply_address_steady = population_apply_address_steady,
      population_score_fake_gps = num(127),
      population_score_fake_contacts = num(128),
      text = text,
      flag = arr(141)
    )
  }

  /**
   * Entry point: reads the training file, fits the pipeline on an 80% split,
   * scores the remaining 20%, writes the rows predicted as label 1 to CSV and
   * prints the test error.
   *
   * @param args command-line arguments (unused; paths are hard-coded)
   */
  def main(args: Array[String]): Unit = {
    val inpath1 = "/Users/ant_git/src/data/user_profile_train/part-00000"
    val spark = SparkSession
      .builder()
      .master("local[3]")
      .appName("GBT_Profile")
      .getOrCreate()
    import spark.implicits._

    // Make sure the session is stopped even when a stage fails.
    try {
      // Read the raw lines and turn them into a typed Dataset.
      val originalData = spark.sparkContext
        .textFile(inpath1)
        .map(parseLine)
        .toDS

      // Map the string label to a numeric index. Fitted on the full dataset so
      // that every label value is known to the indexer.
      val labelIndexer = new StringIndexer()
        .setInputCol("flag")
        .setOutputCol("indexedLabel")
        .fit(originalData)

      // Split the "|"-joined categorical text into tokens.
      val tokenizer = new RegexTokenizer()
        .setInputCol("text")
        .setOutputCol("words")
        .setPattern("\\|")

      // Embed the tokens as a 100-dimensional vector. minCount is left at the
      // library default (the original had setMinCount(1) commented out).
      val word2Vec = new Word2Vec()
        .setInputCol("words")
        .setOutputCol("word2feature")
        .setVectorSize(100)
        .setMaxIter(10)

      // Numeric columns plus the Word2Vec output that make up the feature vector.
      val featureCols = Array(
        "population_age",
        "action_this_month_once_week_average_login_count",
        "population_rank1_cell_cnt",
        "population_rank2_cell_cnt",
        "population_rank3_cell_cnt",
        "population_gps_censu_flag",
        "population_own_censu_flag",
        "population_gps_own_flag",
        "population_own_txl_flag",
        "population_gps_txl_flag",
        "population_censu_txl_flag",
        "population_cnt_7day_province",
        "population_cnt_7day_city",
        "population_cnt_login",
        "population_before_in_apply_address",
        "population_before_after_apply_address",
        "population_in_after_apply_address",
        "population_score_fake_gps",
        "population_score_fake_contacts",
        "word2feature"
      )

      // Merge all feature columns into a single vector column.
      val vectorAssembler = new VectorAssembler()
        .setInputCols(featureCols)
        .setOutputCol("assemblerVector")

      // GBT classifier: 25 boosting iterations, trees of depth at most 5.
      val gbt = new GBTClassifier()
        .setLabelCol("indexedLabel")
        .setFeaturesCol("assemblerVector")
        .setMaxIter(25)
        .setMaxDepth(5)

      // Translate the numeric prediction back to the original label string.
      val labelConverter = new IndexToString()
        .setInputCol("prediction")
        .setOutputCol("predictedLabel")
        .setLabels(labelIndexer.labels)

      val Array(trainingData, testData) = originalData.randomSplit(Array(0.8, 0.2))

      val pipeline = new Pipeline()
        .setStages(Array(labelIndexer, tokenizer, word2Vec, vectorAssembler, gbt, labelConverter))

      // Fit on the training split only; fitting on the full dataset would let
      // the model see the test rows and make the reported error optimistic.
      val model = pipeline.fit(trainingData)

      val predictionResultDF = model.transform(testData)
      predictionResultDF.show(false)

      val labelled = predictionResultDF.select("cid", "flag", "predictedLabel")

      // Rows whose true label is 1.
      val label_1 = labelled
        .filter($"flag" === 1)
        .count()
      // Rows correctly predicted as 1.
      val correct_1 = labelled
        .filter($"flag" === $"predictedLabel" && $"predictedLabel" === 1)
        .count()
      // Rows correctly predicted as 0.
      val correct_0 = labelled
        .filter($"flag" === $"predictedLabel" && $"predictedLabel" === 0)
        .count()
      println("Test rows labelled 1 = " + label_1)
      println("Correctly predicted 1 = " + correct_1)
      println("Correctly predicted 0 = " + correct_0)

      // Persist the ids predicted as 1 into a single CSV part file.
      // NOTE(review): the default save mode fails if the directory already exists.
      predictionResultDF.select("cid", "predictedLabel")
        .filter($"predictedLabel" === 1)
        .repartition(1)
        .write
        .format("csv")
        .save("/Users/ant_git/Antifraud/src/data/predict/")

      val evaluator = new MulticlassClassificationEvaluator()
        .setLabelCol("indexedLabel")
        .setPredictionCol("prediction")
        .setMetricName("accuracy")
      val accuracy = evaluator.evaluate(predictionResultDF)
      val error = 1.0 - accuracy
      println("Test Error = " + error)
    } finally {
      spark.stop()
    }
  }
}
总结:算法是别人封装好的,最重要的是如何处理特征。特征好,用很简单的算法也能取得不错的分类效果;特征不好,再好的模型也很难有好的效果。所以如何进行特征选择,对于机器学习来说非常重要。