Ad Click Prediction with Spark MLlib – Gradient-Boosted Trees


Keywords: spark, mllib, Gradient-Boosted Trees, ad click prediction

This post uses Spark MLlib's Gradient-Boosted Trees algorithm to predict whether a user will click on an ad.

The training and test data are the sample data from the Kaggle Avazu CTR competition, available at: https://www.kaggle.com/c/avazu-ctr-prediction/data

The data format is as follows:


Each record contains 24 fields:

  • 1-id: ad identifier
  • 2-click: 0/1 for non-click/click
  • 3-hour: format is YYMMDDHH, so 14091123 means 23:00 on Sept. 11, 2014 UTC.
  • 4-C1 — anonymized categorical variable
  • 5-banner_pos
  • 6-site_id
  • 7-site_domain
  • 8-site_category
  • 9-app_id
  • 10-app_domain
  • 11-app_category
  • 12-device_id
  • 13-device_ip
  • 14-device_model
  • 15-device_type
  • 16-device_conn_type
  • 17-24 - C14-C21: anonymized categorical variables

In the code below, columns 6 through 14 (array indices 5-13, site_id through device_model) are treated as categorical features, and columns 16 through 23 (indices 15-22) are treated as numerical features.
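The column slicing can be sketched in plain Scala on a single record. The categorical and numerical values below match a sample shown later in the program output; the hour, C1, banner_pos, device_type, and C21 values are made up for illustration:

```scala
// One CSV record with all 24 fields: id, click, hour, C1, banner_pos,
// then 9 categorical fields, device_type, 8 numerical fields, and C21.
// (hour, C1, banner_pos, device_type, C21 values are illustrative only.)
val line = "1000009418151094273,0,14102100,1005,0," +
  "1fbe01fe,f3845767,28905ebd,ecad2386,7801e8d9,07d7df22,a99f214a,ddd2926e,44956a24," +
  "1,2,15706,320,50,1722,0,35,-1,79"
val tokens = line.split(",", -1)

// key combines the id and the click label
val catkey = tokens(0) + "::" + tokens(1)
// indices 5..13 -> columns 6..14 (site_id through device_model)
val catfeatures = tokens.slice(5, 14)
// indices 15..22 -> columns 16..23 (the last column, C21, is dropped)
val numericalfeatures = tokens.slice(15, tokens.size - 1)

println(catkey)                          // 1000009418151094273::0
println(catfeatures.length)              // 9
println(numericalfeatures.mkString(",")) // 2,15706,320,50,1722,0,35,-1
```

Note that `slice(5, 14)` yields 9 categorical values and `slice(15, tokens.size - 1)` yields 8 numerical values, matching the sample output shown below.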

The Spark code is as follows:

 
package com.lxw1234.test

import scala.collection.mutable.ListBuffer
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.SparkConf
import org.apache.spark.rdd.RDD

import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.Vectors

import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel

/**
 * By: lxw
 * http://lxw1234.com
 */
object CtrPredict {

  //input:  (1fbe01fe,f3845767,28905ebd,ecad2386,7801e8d9)
  //output: ((0,1fbe01fe),(1,f3845767),(2,28905ebd),(3,ecad2386),(4,7801e8d9))
  def parseCatFeatures(catfeatures: Array[String]): List[(Int, String)] = {
    val catfeatureList = new ListBuffer[(Int, String)]()
    for (i <- 0 until catfeatures.length) {
      catfeatureList += i -> catfeatures(i)
    }
    catfeatureList.toList
  }

  def main(args: Array[String]) {
    //an application name is required, otherwise SparkContext refuses to start
    val conf = new SparkConf().setMaster("yarn-client").setAppName("CtrPredict")
    val sc = new SparkContext(conf)

    val ctrRDD = sc.textFile("/tmp/lxw1234/sample.txt", 10)
    println("Total records : " + ctrRDD.count)

    //use 80% of the data for training and 20% for testing
    val train_test_rdd = ctrRDD.randomSplit(Array(0.8, 0.2), seed = 37L)
    val train_raw_rdd = train_test_rdd(0)
    val test_raw_rdd = train_test_rdd(1)

    println("Train records : " + train_raw_rdd.count)
    println("Test records : " + test_raw_rdd.count)

    //cache train, test
    train_raw_rdd.cache()
    test_raw_rdd.cache()

    val train_rdd = train_raw_rdd.map { line =>
      val tokens = line.split(",", -1)
      //key combines the id and the click label
      val catkey = tokens(0) + "::" + tokens(1)
      //indices 5..13 (columns 6..14, site_id through device_model) are
      //categorical features that need an encoding step
      val catfeatures = tokens.slice(5, 14)
      //indices 15..22 (columns 16..23) are numerical features, used directly
      val numericalfeatures = tokens.slice(15, tokens.size - 1)
      (catkey, catfeatures, numericalfeatures)
    }

    //inspect one record
    train_rdd.take(1)
    //scala> train_rdd.take(1)
    //res6: Array[(String, Array[String], Array[String])] = Array((1000009418151094273::0,Array(1fbe01fe,
    //            f3845767, 28905ebd, ecad2386, 7801e8d9, 07d7df22, a99f214a, ddd2926e, 44956a24),
    //              Array(2, 15706, 320, 50, 1722, 0, 35, -1)))

    //turn each categorical feature into a (position, value) pair
    val train_cat_rdd = train_rdd.map { x =>
      parseCatFeatures(x._2)
    }

    train_cat_rdd.take(1)
    //scala> train_cat_rdd.take(1)
    //res12: Array[List[(Int, String)]] = Array(List((0,1fbe01fe), (1,f3845767), (2,28905ebd),
    //        (3,ecad2386), (4,7801e8d9), (5,07d7df22), (6,a99f214a), (7,ddd2926e), (8,44956a24)))

    //deduplicate the (position, value) pairs from train_cat_rdd and assign
    //each a numeric index
    val oheMap = train_cat_rdd.flatMap(x => x).distinct().zipWithIndex().collectAsMap()
    //oheMap: scala.collection.Map[(Int, String),Long] = Map((7,608511e9) -> 31527, (7,b2d8fbed) -> 42207,
    //  (7,1d3e2fdb) -> 52791
    println("Number of features")
    println(oheMap.size)

    //encode the training data (note: despite the "OHE" name, this maps each
    //categorical value to its index as a single dense feature, not to a
    //true one-hot vector)
    val ohe_train_rdd = train_rdd.map { case (key, categorical_features, numerical_features) =>
      val cat_features_indexed = parseCatFeatures(categorical_features)
      val cat_feature_ohe = new ArrayBuffer[Double]
      for (k <- cat_features_indexed) {
        if (oheMap contains k) {
          cat_feature_ohe += oheMap(k).toDouble
        } else {
          cat_feature_ohe += 0.0
        }
      }
      //map negative values (e.g. -1) to 0 before converting to Double
      val numerical_features_dbl = numerical_features.map { x =>
        val x1 = if (x.toInt < 0) "0" else x
        x1.toDouble
      }
      val features = cat_feature_ohe.toArray ++ numerical_features_dbl
      LabeledPoint(key.split("::")(1).toInt, Vectors.dense(features))
    }

    ohe_train_rdd.take(1)
    //res15: Array[org.apache.spark.mllib.regression.LabeledPoint] =
    //  Array((0.0,[43127.0,50023.0,57445.0,13542.0,31092.0,14800.0,23414.0,54121.0,
    //     17554.0,2.0,15706.0,320.0,50.0,1722.0,0.0,35.0,0.0]))

    //train the model
    //val boostingStrategy = BoostingStrategy.defaultParams("Regression")
    val boostingStrategy = BoostingStrategy.defaultParams("Classification")
    boostingStrategy.numIterations = 100
    boostingStrategy.treeStrategy.numClasses = 2
    boostingStrategy.treeStrategy.maxDepth = 10
    boostingStrategy.treeStrategy.categoricalFeaturesInfo = Map[Int, Int]()

    val model = GradientBoostedTrees.train(ohe_train_rdd, boostingStrategy)
    //save the model
    model.save(sc, "/tmp/myGradientBoostingClassificationModel")
    //load the model
    val sameModel = GradientBoostedTreesModel.load(sc, "/tmp/myGradientBoostingClassificationModel")

    //apply the same parsing to the test set
    val test_rdd = test_raw_rdd.map { line =>
      val tokens = line.split(",", -1)
      val catkey = tokens(0) + "::" + tokens(1)
      val catfeatures = tokens.slice(5, 14)
      val numericalfeatures = tokens.slice(15, tokens.size - 1)
      (catkey, catfeatures, numericalfeatures)
    }

    //encode the test set with the mapping built from the training data;
    //unseen (position, value) pairs fall back to 0.0
    val ohe_test_rdd = test_rdd.map { case (key, categorical_features, numerical_features) =>
      val cat_features_indexed = parseCatFeatures(categorical_features)
      val cat_feature_ohe = new ArrayBuffer[Double]
      for (k <- cat_features_indexed) {
        if (oheMap contains k) {
          cat_feature_ohe += oheMap(k).toDouble
        } else {
          cat_feature_ohe += 0.0
        }
      }
      val numerical_features_dbl = numerical_features.map { x =>
        val x1 = if (x.toInt < 0) "0" else x
        x1.toDouble
      }
      val features = cat_feature_ohe.toArray ++ numerical_features_dbl
      LabeledPoint(key.split("::")(1).toInt, Vectors.dense(features))
    }

    //run the model on the test set
    val b = ohe_test_rdd.map { y =>
      val s = model.predict(y.features)
      (s, y.label, y.features)
    }

    b.take(10).foreach(println)

    //prediction accuracy
    val predictions = ohe_test_rdd.map(lp => sameModel.predict(lp.features))
    predictions.take(10).foreach(println)
    val predictionAndLabel = predictions.zip(ohe_test_rdd.map(_.label))
    val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count / ohe_test_rdd.count
    println("GBTR accuracy " + accuracy)
    //GBTR accuracy 0.8227084119200302

  }

}
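What the oheMap step above does can be illustrated without Spark: every distinct (position, value) pair seen in training gets a numeric index, and each record's categorical features are replaced by those indices, with 0.0 for unseen pairs. A minimal sketch of the same logic on plain Scala collections, using made-up toy values:

```scala
// two toy "records", each with three categorical features
val records = Seq(
  Array("a", "x", "p"),
  Array("b", "x", "q")
)

// (position, value) pairs, as parseCatFeatures produces them
val pairs = records.map(r => r.zipWithIndex.map { case (v, i) => (i, v) }.toList)

// deduplicate and index, mirroring distinct().zipWithIndex().collectAsMap()
val oheMap = pairs.flatten.distinct.zipWithIndex.toMap

// encode one record: known pairs map to their index, unknown pairs to 0.0
def encode(r: Array[String]): Array[Double] =
  r.zipWithIndex.map { case (v, i) => oheMap.get((i, v)).map(_.toDouble).getOrElse(0.0) }

println(oheMap.size)                              // 5 distinct pairs
println(encode(Array("a", "x", "r")).mkString(",")) // prints 0.0,1.0,0.0
```

Two caveats follow directly from this scheme: the pair that happens to receive index 0 is indistinguishable from an unseen pair, and the encoding is really index encoding rather than one-hot encoding, so the tree model sees the category indices as ordered numeric values.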

With this split, the training set has 104,558 records (Train records : 104558) and the test set 26,510 (Test records : 26510).

The main program output:

 
scala> train_rdd.take(1)
res23: Array[(String, Array[String], Array[String])] = Array((1000009418151094273::0,
Array(1fbe01fe, f3845767, 28905ebd, ecad2386, 7801e8d9, 07d7df22, a99f214a, ddd2926e, 44956a24),
Array(2, 15706, 320, 50, 1722, 0, 35, -1)))

scala> train_cat_rdd.take(1)
res24: Array[List[(Int, String)]] = Array(List((0,1fbe01fe), (1,f3845767), (2,28905ebd),
(3,ecad2386), (4,7801e8d9), (5,07d7df22), (6,a99f214a), (7,ddd2926e), (8,44956a24)))

scala> println("Number of features")
Number of features

scala> println(oheMap.size)
57606

scala> ohe_train_rdd.take(1)
res27: Array[org.apache.spark.mllib.regression.LabeledPoint] = Array(
(0.0,[11602.0,22813.0,11497.0,16828.0,30657.0,23893.0,13182.0,31723.0,39722.0,2.0,15706.0,320.0,50.0,1722.0,0.0,35.0,0.0]))

scala> println("GBTR accuracy " + accuracy)
GBTR accuracy 0.8227084119200302
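The accuracy figure at the end simply counts matching (prediction, label) pairs; the same arithmetic on plain Scala collections, with made-up pairs:

```scala
// predictions zipped with true labels, as predictionAndLabel holds them
val predictionAndLabel = Seq(
  (0.0, 0.0), (0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)
)

// accuracy = fraction of pairs where prediction equals label (3 of 5 here)
val accuracy = 1.0 * predictionAndLabel.count { case (p, l) => p == l } / predictionAndLabel.size
println(accuracy) // prints 0.6
```

Since non-clicks dominate CTR data, a model that always predicts 0 can already score high on accuracy, so metrics such as AUC or log loss (the metric used by the Avazu competition) are usually reported alongside it.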

 
