K-means: algorithm analysis and a case study (user-group recall on big data)

I. The K-means algorithm

1. Introduction

k-means is a partition-based clustering algorithm. Given a parameter k, it splits n data objects into k clusters such that similarity is high within each cluster and low between clusters.

K-means is the most widely used partition-based clustering algorithm. It is a hard-clustering method and a classic representative of prototype-based objective-function clustering. The algorithm first picks k objects at random, each initially representing a cluster's mean, or center. Every remaining object is assigned to the nearest cluster center by distance, after which each cluster's mean is recomputed. This repeats until the clustering criterion function converges. The criterion is usually one of two things: a global error function, or the change in cluster centers between consecutive iterations.

Clustering differs from classification. Classification is supervised learning: the classes must be known before classifying, and every element is asserted to map to one class. Clustering is observational learning: before clustering, the classes, and even the number of classes, may be unknown, which makes it a form of unsupervised learning. Clustering is now widely used in statistics, biology, database technology, marketing, and other fields, and correspondingly many algorithms exist.

Since K-means is unsupervised, its biggest distinction and advantage is that building the model needs no training data. In day-to-day work there is often no way to obtain effective training data up front, and K-means is then a good choice. However, K-means requires the number of clusters (the value of K) to be set in advance, which makes it unusable for scenarios where that number is unknowable, such as computing the social circles of every telecom user in a province. When K is known to be fairly small but its exact value is unclear, you can run the algorithm over a range of K values and pick the K with the lowest cost; that value usually describes the number of clusters well.

2. Basic idea and how it works

Basic idea

Given a dataset of n data objects, k-means builds k partitions, each partition being one cluster. The data is divided into k clusters; every cluster contains at least one data object, and every data object belongs to exactly one cluster. Objects within the same cluster should be highly similar, while objects in different clusters should be dissimilar. Cluster similarity is computed from the mean of the objects in each cluster.

The procedure is as follows. First, k data objects are chosen at random, each representing a cluster center; these are the k initial centers. Each remaining object is assigned, according to its similarity (distance) to the cluster centers, to the cluster whose center it is most similar to. Then the mean of all objects in each cluster is recomputed and becomes the new cluster center.

This process repeats until the criterion function converges, i.e. the cluster centers stop changing appreciably. The criterion is usually the mean squared error: minimize the sum of squared distances from each point to its nearest cluster center.

A new cluster center is computed as the mean of all objects in the cluster: average the objects dimension by dimension to get the center point. For example, if a cluster contains the 3 data objects {(6,4,8), (8,2,2), (4,6,2)}, its center is ((6+8+4)/3, (4+2+6)/3, (8+2+2)/3) = (6,4,4).
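As a quick check, that centroid computation is a one-liner in plain Scala (a standalone sketch, separate from the Spark code later in this article):

// Centroid = per-dimension mean of the cluster's members.
val cluster = Seq(Array(6.0, 4.0, 8.0), Array(8.0, 2.0, 2.0), Array(4.0, 6.0, 2.0))
val center = cluster.transpose.map(dim => dim.sum / cluster.size)
println(center.mkString("(", ",", ")"))  // prints (6.0,4.0,4.0)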

k-means uses distance to describe the similarity of two data objects. Candidate distance functions include the Minkowski, Euclidean, Mahalanobis, and Canberra distances, of which the Euclidean distance is the most common.
k-means terminates when the criterion function reaches an optimum or the maximum number of iterations is hit. With Euclidean distance, the criterion is usually the minimized sum of squared distances from each data object to its cluster center, that is:

$$E = \sum_{i=1}^{k} \sum_{x \in C_i} dist(C_i, x)^2$$

where k is the number of clusters, $C_i$ is the center of the i-th cluster, and $dist(C_i, x)$ is the distance from point x to $C_i$.
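In code, this criterion is just the sum, over all points, of the squared distance to the nearest center. A minimal sketch (the names sqDist and wsse are mine, not from any library):

// Squared Euclidean distance between two equal-length vectors.
def sqDist(a: Array[Double], b: Array[Double]): Double =
  a.zip(b).map { case (x, y) => (x - y) * (x - y) }.sum

// Criterion function E: every point contributes the squared
// distance to its nearest cluster center.
def wsse(points: Seq[Array[Double]], centers: Seq[Array[Double]]): Double =
  points.map(p => centers.map(c => sqDist(p, c)).min).sum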

How it works

1. Randomly choose k objects from dataset D as the initial cluster centers.

2. Using the cluster centers, assign each of the n objects in the dataset to the most "similar" cluster ("similar" is judged by distance).

3. Recompute the center value of each cluster.

4. Compute the criterion function.

5. If the criterion function meets the threshold, stop; otherwise return to step 2. A standalone sketch of these steps follows below.
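The five steps map directly onto a short standalone implementation. The following is a didactic sketch only (plain Scala, Euclidean distance, reusing the sqDist helper from above, stopping once the centers move less than a threshold); the Spark implementations below do the same job at scale:

import scala.util.Random

def kMeans(points: Seq[Array[Double]], k: Int, epsilon: Double = 1e-4): Seq[Array[Double]] = {
  // Step 1: randomly choose k objects as the initial cluster centers.
  var centers = Random.shuffle(points).take(k)
  var moved = Double.MaxValue
  while (moved > epsilon) {  // Step 5: repeat until the criterion is satisfied
    // Step 2: assign every object to the most "similar" (nearest) center.
    val clusters = points.groupBy(p => centers.minBy(c => sqDist(p, c)))
    // Step 3: recompute each center as the mean of its members.
    val newCenters = centers.map { c =>
      val members = clusters.getOrElse(c, Seq(c)) // an empty cluster keeps its old center
      members.transpose.map(dim => dim.sum / members.size).toArray
    }
    // Step 4: the criterion here is how far the centers moved this round.
    moved = centers.zip(newCenters).map { case (a, b) => sqDist(a, b) }.sum
    centers = newCenters
  }
  centers
}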

3. A simple worked example

Data: player information (one month)

| Player (ID) | Game time (hours) | Recharge amount (yuan) |
| --- | --- | --- |
| 1 | 60 | 55 |
| 2 | 90 | 86 |
| 3 | 30 | 22 |
| 4 | 15 | 11 |
| 5 | 288 | 300 |
| 6 | 223 | 200 |
| 7 | 0 | 0 |
| 8 | 14 | 5 |
| 9 | 320 | 280 |
| 10 | 65 | 55 |
| 11 | 13 | 0 |
| 12 | 10 | 18 |
| 13 | 115 | 108 |
| 14 | 3 | 0 |
| 15 | 52 | 40 |
| 16 | 62 | 76 |
| 17 | 73 | 80 |
| 18 | 45 | 30 |
| 19 | 1 | 0 |
| 20 | 180 | 166 |

The data is abstracted as follows, each line meaning (game time in hours, recharge amount in yuan).

The players are split into 3 groups:

1. Premium users (long playtime, high spend)

2. Regular players (medium playtime, medium spend)

3. Inactive users (short playtime, low spend)

Flowchart (figure omitted)
Test code

import org.apache.spark.mllib.clustering.KMeans
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.{SparkConf, SparkContext}

object KMeansTest {
  def main(args: Array[String]) {
    val conf = new SparkConf()
    val sc = new SparkContext(conf)

    val data = sc.textFile(args(0))
    val parsedData = data.map(s => Vectors.dense(s.split(' ').map(_.trim.toDouble))).cache()

    // Set the number of clusters to 3
    val numClusters = 3
    // Iterate at most 20 times
    val numIterations = 20
    // Run 10 times and keep the best solution
    val runs = 10
    val clusters = KMeans.train(parsedData, numClusters, numIterations, runs)
    // Evaluate clustering by computing Within Set Sum of Squared Errors
    val WSSSE = clusters.computeCost(parsedData)
    println("Within Set Sum of Squared Errors = " + WSSSE)

    val a21 = clusters.predict(Vectors.dense(57.0, 30.0))
    val a22 = clusters.predict(Vectors.dense(0.0, 0.0))

    // Print the cluster centers
    println("Cluster centers:")
    for (center <- clusters.clusterCenters) {
      println(" " + center)
    }

    // Print which cluster each data point belongs to
    println(parsedData.map(v => v.toString() + " belongs to cluster: " + clusters.predict(v)).collect().mkString("\n"))
    println("Predicted group for user 21 --> " + a21)
    println("Predicted group for user 22 --> " + a22)
  }
}
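For reference, the program reads its input from args(0) as a plain text file, one player per line, space-separated (game time, recharge amount), matching the table above:

60 55
90 86
30 22
15 11
288 300
...

One caveat: the runs argument of KMeans.train was deprecated in Spark 1.6 and is ignored from Spark 2.0 onward (the algorithm then always runs once), so on recent versions the overload without runs is the one to use.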

Interpreting the output

[Figure: console output of the run]

From the output it is clear that:

Group 1 users are the premium users

Group 2 users are the regular users

Group 3 users are the inactive users

User 21's data is (57, 30)

User 22's data is (0, 0)

Both classifications are correct.

[Figure: the three cluster centers]

4. Strengths and weaknesses

k-means is a classic clustering algorithm: simple, efficient, and easy to understand and implement. Its time complexity is low, O(tkmn), where t is the number of iterations, k the number of clusters, m the number of records, and n the number of dimensions, with t ≪ m and k ≪ n.

k-means also has a number of shortcomings.

  • The number of clusters must be fixed by hand in advance, and choosing k is often a difficult problem.
  • It is sensitive to initialization: the result depends on the choice of initial centers.
  • It is very sensitive to noise and outlier data. A single outlier with very large values can severely distort the cluster means.
  • It cannot cluster data whose distribution is non-convex.
  • It mainly discovers circular or spherical clusters and cannot recognize non-spherical ones.

II. K-means in Spark MLlib

1. Parameters and the constructor

The KMeans implementation class in Spark MLlib has the following parameters.

  class KMeans private (
      private var k: Int,
      private var maxIterations: Int,
      private var runs: Int,
      private var initializationMode: String,
      private var initializationSteps: Int,
      private var epsilon: Double,
      private var seed: Long) extends Serializable with Logging

Parameter meanings

| Name | Description |
| --- | --- |
| k | The desired number of clusters. |
| maxIterations | The maximum number of iterations for a single run. |
| runs | How many times the algorithm is run. k-means is not guaranteed to return the globally optimal clustering, so running it several times on the target dataset helps pick the best result. |
| initializationMode | How the initial cluster centers are chosen; random selection and K_MEANS_PARALLEL are currently supported, with K_MEANS_PARALLEL the default. |
| initializationSteps | The number of steps in the K_MEANS_PARALLEL method. |
| epsilon | The convergence threshold for the k-means iteration. |
| seed | The random seed used for initialization. |

The MLlib k-means constructor
The interface that builds an MLlib k-means instance with all-default values is:

{k: 2, maxIterations: 20, runs: 1, initializationMode: KMeans.K_MEANS_PARALLEL, initializationSteps: 5, epsilon: 1e-4, seed: random}
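The defaults can also be overridden one by one through setters before calling run. A sketch (the k, iteration, and seed values here are arbitrary examples):

import org.apache.spark.mllib.clustering.KMeans

val kmeans = new KMeans()
  .setK(3)                                         // number of clusters
  .setMaxIterations(50)                            // cap on iterations per run
  .setInitializationMode(KMeans.K_MEANS_PARALLEL)  // or KMeans.RANDOM
  .setEpsilon(1e-4)                                // convergence threshold
  .setSeed(42L)                                    // reproducible initialization
val model = kmeans.run(parsedData)                 // parsedData: RDD[Vector]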

2. The MLlib k-means training function

The MLlib training function KMeans.train has many overloads; the one with the most parameters is described here. The KMeans.train method looks like this.

def train(
    data: RDD[Vector],
    k: Int,
    maxIterations: Int,
    runs: Int,
    initializationMode: String,
    seed: Long): KMeansModel = {
  new KMeans().setK(k)
    .setMaxIterations(maxIterations)
    .setRuns(runs)
    .setInitializationMode(initializationMode)
    .setSeed(seed)
    .run(data)
}

The parameters have the same meanings as in the constructor, so they are not repeated here.

3. The MLlib k-means prediction function

The MLlib prediction function KMeansModel.predict accepts input in different formats, either a single vector or an RDD, and returns the index of the cluster the input belongs to. Its API is as follows.

def predict(point: Vector): Int
def predict(points: RDD[Vector]): RDD[Int]

The first overload takes a single point and returns the index of the cluster it falls in; the second takes a set of points and returns every point's cluster index as an RDD.
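A minimal usage sketch of both overloads, reusing the clusters model and parsedData from the test program above:

// One point in, one cluster index out.
val single: Int = clusters.predict(Vectors.dense(57.0, 30.0))

// An RDD of points in, an RDD of cluster indices out (in input order).
val many: RDD[Int] = clusters.predict(parsedData)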

III. A k-means example with Spark ML

Tables in the MySQL database:
User table: customs
Goods table: goods
Orders table: orders
Order detail table: orderItems
The goal is user-group recall over this existing data:

1. Connecting to the database

A method that reads a database table (the table name is passed in as a parameter):

  def readMysql(spark: SparkSession, tableName: String) = {
    val map: Map[String, String] = Map[String, String](
      "url" -> "jdbc:mysql://192.168.109.138:3306/myshops",
      "driver" -> "com.mysql.jdbc.Driver",
      "user" -> "root",
      "password" -> "root",
      "dbtable" -> tableName
    )
    spark.read.format("jdbc").options(map).load()
  }
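Usage is then one line per table. Note the assumptions baked in here: the MySQL JDBC driver (a mysql-connector-java jar) must be on the Spark classpath, and the URL, user, and password above are specific to the author's environment:

val customs = readMysql(spark, "customs") // DataFrame backed by the customs table
customs.printSchema()                     // quick check that the connection works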

2. Custom functions

Some user-defined functions needed by the business logic:

  // Bucket membership points into levels 1-4.
  val func_membership = udf{
    (score:Int)=> {
      score match {
        case i if i<100 => 1
        case i if i<500 => 2
        case i if i<1000 => 3
        case _ => 4
      }
    }
  }

  // Age from a Chinese ID number (the birth date sits at positions 6-14),
  // relative to "now" given as a yyyy-MM-dd string.
  val func_bir = udf {
    (idno: String, now: String) => {
      var year = idno.substring(6, 10).toInt
      var month = idno.substring(10, 12).toInt
      val day = idno.substring(12, 14).toInt


      val dts = now.split("-")
      val nowYear = dts(0).toInt
      val nowMonth = dts(1).toInt
      val nowDay = dts(2).toInt

      if (nowMonth > month) {
        nowYear - year
      } else if (nowMonth < month) {
        nowYear - 1 - year
      } else {
        if (nowDay >= day) {
          nowYear - year
        } else {
          nowYear - 1 - year
        }
      }
    }
  }

  // Bucket age into 7 bands.
  val func_claAge = udf{
    (bir:Int)=>{
      bir match {
        case i if i<10 => 1
        case i if i<18 => 2
        case i if i<23 => 3
        case i if i<35 => 4
        case i if i<50 => 5
        case i if i<70 => 6
        case _ => 7
      }
    }
  }

  // Bucket user points (biz_point) into 3 levels.
  val func_userScore = udf{
    (score:Int)=>{
      score match {
        case i if i<100 => 1
        case i if i<500 => 2
        case _ => 3
      }
    }
  }

  // Bucket login count into 2 levels.
  val func_logcount = udf{
    (score:Int)=>{
      score match {
        case i if i<500 => 1
        case _ => 2
      }
    }
  }
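Since these are ordinary org.apache.spark.sql.functions.udf wrappers, they can be applied directly to columns. A tiny self-contained check of the first one (the demo DataFrame is made up for illustration):

import spark.implicits._
import org.apache.spark.sql.functions.col

val demo = Seq(50, 250, 750, 1500).toDF("membership_level")
demo.withColumn("mslevel", func_membership(col("membership_level"))).show()
// 50 -> 1, 250 -> 2, 750 -> 3, 1500 -> 4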

3. Data cleaning

Analyze the data in the tables, pick out the columns relevant to group recall, and build the DataFrames:

def main(args: Array[String]): Unit = {
 val spark = SparkSession.builder().appName("db").master("local[*]").getOrCreate()

    // Users table
    /**
     * Columns needed from the users table and how each is cleaned:
     * user id          cust_id
     * company          company (string -> number)
     * province         province_id
     * city             city_id
     * district         district_id
     * membership level membership_level (bucketed: 0-100 -> 1, 100-500 -> 2, 500-1000 -> 3, 1000+ -> 4)
     * created at       create_at (max - min = days)
     * last login       last_login_time (max - min = days)
     * account status   active (filter): 0 closed, 1 active
     * id number        idno -> compute age first, then bucket (0-10 -> 1, 11-18 -> 2, 19-23 -> 3, 24-35 -> 4, 36-50 -> 5, 51-70 -> 6, 71+ -> 7)
     * points           biz_point (0-100 -> 1, 100-500 -> 2, 500+ -> 3)
     * gender           sex 0/1 (may be absent; can be derived from the id number)
     * marital status   marital_status: 0 unmarried, 1 married
     * education        education_id
     * login count      login_count (0-500 -> 1, 500+ -> 2)
     * industry         vocation
     * position         post
     */
    val userTable = readMysql(spark,"customs")
    // Keep only active (non-closed) accounts
      .filter("active!=0")
      .select("cust_id","company","province_id","city_id","district_id"
        ,"membership_level","create_at","last_login_time","active","idno","biz_point"
        ,"sex","marital_status","education_id","login_count","vocation","post")
    // Goods table
    val goodTable = readMysql(spark,"goods").select("good_id","price")
    // Orders table
    val orderTable = readMysql(spark,"orders").select("ord_id","cust_id")
    // Order detail table
    val ordertailTable = readMysql(spark,"orderItems").select("ord_id","good_id","buy_num")


4. Business processing

The related business processing:

// Turn company names into numbers with a StringIndexer
    val compIndex = new StringIndexer().setInputCol("company").setOutputCol("compId")
    // Import implicit conversions
    import spark.implicits._

    // Count how many orders each user placed
    val tmp_bc = orderTable.groupBy("cust_id").agg(count($"ord_id").as("buycount"))

    // How much money each user has spent on the site
    val tmp_pay = orderTable.join(ordertailTable, Seq("ord_id"), "inner")
      .join(goodTable, Seq("good_id"), "inner")
      .groupBy("cust_id")
      .agg(sum($"buy_num" * $"price").as("pay"))

    val df = compIndex.fit(userTable).transform(userTable)
      // Membership points level
      .withColumn("mslevel", func_membership($"membership_level"))
      // Convert the string timestamps into day offsets from the earliest value
      .withColumn("reg_date", datediff($"create_at", min($"create_at").over()))
      .withColumn("last_login_date", datediff($"last_login_time", min($"last_login_time").over()))
      // Derive age from the id number
      .withColumn("bir", func_claAge(func_bir($"idno", lit(current_date()))))
      // User points level
      .withColumn("user_score", func_userScore($"biz_point"))
      // Login count level
      .withColumn("logcount", func_logcount($"login_count"))
      // Join in the two aggregates
      .join(tmp_bc, Seq("cust_id"), "left")
      .join(tmp_pay, Seq("cust_id"), "left")
      // Replace nulls with 0
      .na.fill(0)
      // Drop the already-cleaned, no-longer-meaningful columns
      .drop("company", "membership_level"
        , "create_at", "idno", "biz_point", "login_count", "last_login_time")
    // Force every column to Double,
    // which simplifies the plotting and the k-means input below
    val columns = df.columns.map(f => col(f).cast(DoubleType))
    val num_fmt = df.select(columns: _*)
    // Assemble the listed columns into a single vector column
    val va = new VectorAssembler().setInputCols(
      Array("cust_id","province_id","city_id","district_id"
        ,"active","sex","marital_status","education_id"
        ,"vocation","post","compId","mslevel","reg_date"
        ,"last_login_date","bir","user_score","logcount"
        ,"buycount","pay")).setOutputCol("orign_feature")
    val ofdf = va.transform(num_fmt).select("cust_id", "orign_feature")
    // Normalize the raw feature column; this helps convergence
    val mmScaler = new MinMaxScaler().setInputCol("orign_feature").setOutputCol("feature")
    val resdf = mmScaler.fit(ofdf).transform(ofdf)
      .select("cust_id", "feature").cache()
      // To display it:
      // val resdf = mmScaler.fit(ofdf).transform(ofdf).select("cust_id","feature").show(false)

+-------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|cust_id|feature                                                                                                                                                                                                                                                                  |
+-------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|496.0  |[0.004920393631490519,0.48484848484848486,0.4749262536873156,0.4754339511621065,0.5,0.0,0.0,1.0,0.9859719438877755,0.9851940776310524,0.6229854955680902,0.3333333333333333,0.5067753077672482,0.3793429003021148,0.5,1.0,0.0,0.14285714285714285,0.08377923480794836]   |
|833.0  |[0.008290663253060245,0.8787878787878788,0.8643067846607669,0.8634892615475139,0.5,0.0,0.0,0.25,0.905811623246493,0.9051620648259304,0.5506647864625303,0.3333333333333333,0.5067753077672482,0.7931457703927492,0.0,1.0,0.0,0.2857142857142857,0.08109126535533952]     |
|1088.0 |[0.01084086726938155,0.6363636363636364,0.6371681415929203,0.6363636363636364,0.5,0.0,0.0,0.75,0.8877755511022044,0.8879551820728291,0.21555197421434327,0.6666666666666666,0.48516517580373747,0.7585913897280967,0.5,0.0,0.0,0.0,0.0]                                  |
|1580.0 |[0.01576126090087207,0.5151515151515151,0.5191740412979351,0.5198587819947044,0.5,0.0,0.0,0.5,0.657314629258517,0.657062825130052,0.4879129734085415,0.6666666666666666,0.42848286245682404,0.8620657099697885,0.5,1.0,0.0,0.5714285714285714,0.49068421801860646]       |
|1591.0 |[0.015871269701576127,0.21212121212121213,0.20648967551622419,0.20623712856722565,0.5,1.0,0.0,0.75,0.969939879759519,0.9699879951980792,0.7966357775987107,0.0,0.20467629085112035,0.7585913897280967,0.0,1.0,0.0,0.2857142857142857,0.12288458478827007]                |
|1829.0 |[0.018251460116809344,0.24242424242424243,0.25663716814159293,0.25831126802000587,0.5,0.0,1.0,0.0,1.0,0.9987995198079231,0.2694399677679291,0.6666666666666666,0.7250022141528651,0.20685422960725075,0.0,1.0,0.0,0.0,0.0]                                               |
|2366.0 |[0.023621889751180094,0.21212121212121213,0.2153392330383481,0.21535745807590467,0.5,0.0,0.0,0.75,0.6372745490981964,0.636654661864746,0.4395648670427075,0.6666666666666666,0.6711540164732973,0.06891993957703928,0.0,0.0,0.0,0.0,0.0]                                 |
|3749.0 |[0.03745299623969917,0.36363636363636365,0.35398230088495575,0.353927625772286,0.5,1.0,1.0,1.0,0.3066132264529058,0.30612244897959184,0.011583400483481063,0.3333333333333333,0.956868302187583,0.20685422960725075,0.5,0.0,0.0,0.14285714285714285,0.16570787382453672] |
|3794.0 |[0.037903032242579404,0.5757575757575758,0.5752212389380531,0.5760517799352751,0.5,1.0,1.0,1.0,0.7755511022044088,0.7759103641456583,0.01863416599516519,0.6666666666666666,0.36923213178637854,0.7585913897280967,0.5,0.0,0.0,0.0,0.0]                                  |
|4101.0 |[0.04097327786222898,0.9393939393939394,0.9174041297935103,0.9173286260664901,0.5,0.0,0.0,1.0,0.36472945891783565,0.3641456582633053,0.4894238517324738,0.3333333333333333,0.8840669559826411,0.7585913897280967,0.5,1.0,0.0,0.0,0.0]                                    |
|4900.0 |[0.04896391711336907,0.5151515151515151,0.504424778761062,0.5045601647543395,0.5,0.0,1.0,0.75,0.280561122244489,0.2801120448179272,0.11714343271555197,0.0,0.5551324063413338,0.5861971299093656,0.0,1.0,0.0,0.0,0.0]                                                    |
|4935.0 |[0.04931394511560925,0.45454545454545453,0.4631268436578171,0.46337157987643424,0.5,0.0,0.0,0.25,0.04408817635270541,0.044017607042817125,0.11956083803384368,0.3333333333333333,0.8247276591975911,0.7241314199395771,0.5,0.0,0.0,0.2857142857142857,0.1707153683188049]|
|5300.0 |[0.05296423713897112,0.12121212121212122,0.12979351032448377,0.13033245072080024,0.5,1.0,0.0,0.0,0.14428857715430862,0.1452581032412965,0.28031829170024175,0.6666666666666666,0.3746346647772562,0.6896714501510574,0.5,0.0,0.0,0.0,0.0]                                |
|5518.0 |[0.05514441155292423,0.7575757575757576,0.7374631268436578,0.736098852603707,0.5,1.0,1.0,0.75,0.9458917835671342,0.9443777511004402,0.4545729250604351,0.3333333333333333,0.7331502966964839,0.8620657099697885,0.5,1.0,0.0,0.0,0.0]                                     |
|5803.0 |[0.05799463957116569,0.9696969696969697,0.9646017699115044,0.9620476610767873,0.5,1.0,0.0,0.75,0.4849699398797595,0.4841936774709884,0.4734085414987913,0.6666666666666666,0.7088831812948366,0.44826283987915405,0.5,0.0,0.0,0.14285714285714285,0.0057946257358421046] |
|6466.0 |[0.06462517001360109,0.7272727272727273,0.7315634218289085,0.7302147690497205,0.5,0.0,1.0,0.75,0.3486973947895792,0.34973989595838334,0.15783642224012892,0.0,0.382782747320875,0.6552114803625377,0.5,0.0,0.0,0.2857142857142857,0.2529119669069929]                    |
|6620.0 |[0.06616529322345788,0.2727272727272727,0.2890855457227139,0.2900853192115328,0.5,0.0,0.0,0.0,0.8416833667334669,0.8419367747098839,0.20114826752618856,0.6666666666666666,0.471614560269241,0.20685422960725075,0.0,0.0,0.0,0.0,0.0]                                    |
|6654.0 |[0.06650532042563405,0.030303030303030304,0.0471976401179941,0.04766107678729038,0.5,1.0,1.0,0.25,0.9258517034068137,0.9255702280912365,0.3037872683319903,0.3333333333333333,0.7195996811619875,0.3793429003021148,0.5,1.0,0.0,0.14285714285714285,0.0]                 |
|6658.0 |[0.06654532362589007,0.3939393939393939,0.3893805309734513,0.3901147396293027,0.5,1.0,0.0,0.0,0.7314629258517034,0.7306922769107643,0.9811643835616438,0.6666666666666666,0.4824196262509964,0.7585913897280967,0.0,1.0,0.0,0.2857142857142857,0.1025364048199228]       |
|7240.0 |[0.07236578926314105,0.6060606060606061,0.6135693215339233,0.6122388937922919,0.5,0.0,1.0,0.0,0.5050100200400801,0.5046018407362945,0.33340048348106366,0.6666666666666666,0.4500929944203348,0.8620657099697885,0.5,1.0,0.0,0.0,0.0]                                    |
+-------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
only showing top 20 rows
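For the record, what MinMaxScaler did to produce the feature column above is rescale every dimension into [0, 1] independently, so that no single raw feature (pay, say, which can run into the thousands) dominates the Euclidean distances:

$$x' = \frac{x - \min(x)}{\max(x) - \min(x)}$$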

5. Plotting to pick the number of centroids

This uses JFreeChart's ChartPanel for drawing.

A custom chart window class:

import java.awt.Font
import scala.collection.mutable.ListBuffer
import org.jfree.chart.{ChartFactory, ChartPanel}
import org.jfree.chart.plot.PlotOrientation
import org.jfree.data.category.DefaultCategoryDataset
import org.jfree.ui.{ApplicationFrame, RefineryUtilities}

/*
 A simple chart window
 */
class LineGraph(appName: String) extends ApplicationFrame(appName) {
  def this(appName: String, chartTitle: String, list: ListBuffer[Double]) {
    this(appName)
    val lineChart = ChartFactory.createLineChart(chartTitle,
      "centroids", "distance", createDataset(list),
      PlotOrientation.VERTICAL, true, true, false)

    // Chart fonts ("宋体" is a Chinese system font; any installed font works)
    val font = new Font("宋体", Font.BOLD, 20)
    lineChart.getTitle.setFont(font)                           // title font
    lineChart.getLegend.setItemFont(font)                      // legend font
    lineChart.getCategoryPlot.getDomainAxis.setLabelFont(font) // x-axis font
    lineChart.getCategoryPlot.getRangeAxis.setLabelFont(font)  // y-axis font

    val chartPanel = new ChartPanel(lineChart)
    chartPanel.setPreferredSize(new java.awt.Dimension(1600, 1200))
    setContentPane(chartPanel)
  }

  // One category per candidate centroid count, starting at k = 2
  def createDataset(list: ListBuffer[Double]): DefaultCategoryDataset = {
    val dataset = new DefaultCategoryDataset()
    var point = 2
    for (dst <- list) {
      dataset.addValue(dst, "dist", point + "")
      point += 1
    }
    dataset
  }
}

Analyzing the centroid count

Loop k-means over candidate centroid counts, computing the cost (average distance) for each, and read a suitable count off the resulting chart. (Once the count is chosen this step is no longer needed and can be commented out.)
This is extremely memory-hungry; my poor laptop barely made it = _ = and took a long time to finish.


// Compute the total cost for each candidate number of centroids
// and collect the results
    val disLst: ListBuffer[Double] = ListBuffer[Double]()
    for (i <- 2 to 35) {
      val kms = new KMeans()
        .setFeaturesCol("feature")
        .setK(i)

      val model = kms.fit(resdf)
      disLst.append(model.computeCost(resdf))
    }

    // Draw the chart
    val chart = new LineGraph("app", "KMeans centroid distance", disLst)
    chart.pack()
    RefineryUtilities.centerFrameOnScreen(chart)
    chart.setVisible(true)

[Figure: cost curve over candidate centroid counts]
The chart shows the curve getting fairly flat at around 40. (To see it more clearly, raise the 35 above to something larger, and pray your machine can take it.)

6. Group recall

    // Group the users with KMeans
    val kms = new KMeans()
      .setFeaturesCol("feature")
      .setK(40)
    val user_group_tab = kms.fit(resdf).transform(resdf).drop("feature")
      .withColumnRenamed("prediction", "groups")
    // Take the 30 hottest goods in each group
    val rank = 30
    val wnd = Window.partitionBy("groups").orderBy(desc("group_buy_count"))
    user_group_tab
      .join(orderTable, Seq("cust_id"), "left")
      .join(ordertailTable, Seq("ord_id"), "left")
      .na.fill(0)
      .groupBy("groups", "good_id")
      .agg(count("ord_id").as("group_buy_count"))
      .withColumn("rank", row_number().over(wnd))
      .filter($"rank" <= rank)
      .show(false)

    spark.close()
  }

// The good_id 0 row is an artifact of filling nulls with 0 (users with no orders)
+------+-------+---------------+----+
|groups|good_id|group_buy_count|rank|
+------+-------+---------------+----+
|31    |0      |694            |1   |
|31    |2974   |3              |2   |
|31    |28528  |3              |3   |
|31    |21243  |3              |4   |
|31    |21610  |3              |5   |
|31    |2246   |3              |6   |
|31    |8016   |3              |7   |
|31    |14912  |3              |8   |
|31    |16435  |3              |9   |
|31    |15391  |3              |10  |
|31    |21871  |3              |11  |
|31    |23981  |3              |12  |
|31    |21777  |3              |13  |
|31    |9289   |3              |14  |
|31    |8839   |3              |15  |
|31    |6205   |3              |16  |
|31    |27560  |2              |17  |
|31    |29696  |2              |18  |
|31    |20833  |2              |19  |
|31    |15421  |2              |20  |
+------+-------+---------------+----+
only showing top 20 rows
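To turn this into an actual per-user recall list, the per-group top 30 only needs to be joined back to user_group_tab on groups. A sketch, where hot30 is a hypothetical name for the aggregate above materialized as a DataFrame instead of shown:

// hot30: groups, good_id, group_buy_count, rank (the aggregate above, minus .show)
val recalled = user_group_tab
  .join(hot30, Seq("groups"), "inner")
  .select("cust_id", "good_id", "rank") // up to 30 candidate goods recalled per user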

