Contents
I. The K-means Algorithm
1. Overview
The k-means algorithm is a partition-based clustering algorithm. It takes a parameter k and divides n data objects into k clusters, so that similarity within a cluster is high while similarity between clusters is low.
K-means is the most widely used partition-based clustering algorithm. It is a hard clustering algorithm and a classic representative of prototype-based, objective-function clustering. The algorithm first selects k objects at random, each initially representing the mean (center) of a cluster. Each remaining object is assigned, according to its distance to the cluster centers, to the cluster whose center is closest; the mean of each cluster is then recomputed. This process repeats until the clustering criterion function converges. The criterion function usually takes one of two forms: a global error function, or the change in the centers between two consecutive iterations.
Clustering differs from classification. Classification is supervised learning: the classes must be defined before classifying, and every element is asserted to map to one class. Clustering is observational learning: before clustering, the classes, and even the number of classes, need not be known; it is a form of unsupervised learning. Clustering is now widely applied in statistics, biology, database technology, marketing, and other fields, and there are correspondingly many algorithms.
As unsupervised learning, K-means's biggest distinction and advantage is that building the model requires no training data. In everyday work there is often no way to obtain effective training data in advance, and K-means is then a good choice. However, K-means requires the number of clusters (the value of K) to be set beforehand, which makes it essentially unusable for scenarios such as computing the social circles of all telecom users in a province. For scenarios where K is known to be small but its exact value is unclear, you can run the algorithm iteratively and pick the K with the smallest cost; that value usually describes the number of clusters well, as sketched below.
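As a sketch of that scan (assuming Spark MLlib's KMeans.train and computeCost, both introduced later in this article, and an already-parsed parsedData: RDD[Vector]):

// Try a range of k values and keep the one with the lowest cost (WSSSE).
// `parsedData: RDD[Vector]` is assumed to be prepared as in the test code below.
val costs = (2 to 10).map { k =>
  val model = KMeans.train(parsedData, k, 20) // 20 iterations per run
  (k, model.computeCost(parsedData))
}
val (bestK, bestCost) = costs.minBy(_._2)
println(s"best k = $bestK, cost = $bestCost")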
2. Basic Idea and How It Works
Basic idea
Given a dataset of n data objects, the k-means algorithm constructs k partitioned clusters, each partition being one cluster. The method divides the data into k clusters; each cluster contains at least one data object, and each data object belongs to one and only one cluster. At the same time, objects within the same cluster must be highly similar, while objects in different clusters are less similar. Cluster similarity is computed from the mean of the objects in each cluster.
The processing flow of k-means is as follows. First, k data objects are selected at random, each representing one cluster center; that is, k initial centers are chosen. Each remaining object is assigned, according to its similarity (distance) to the cluster centers, to the cluster whose center it is most similar to. The mean of all objects in each cluster is then recomputed and used as the new cluster center.
This process repeats until the criterion function converges, that is, until the cluster centers no longer change noticeably. The criterion function is usually the sum of squared errors: minimize the sum of squared distances from every point to its nearest cluster center.
The new cluster center is the mean of all objects in the cluster: average the values of all objects dimension by dimension to obtain the center point. For example, if a cluster contains the 3 data objects {(6,4,8), (8,2,2), (4,6,2)}, its center is ((6+8+4)/3, (4+2+6)/3, (8+2+2)/3) = (6,4,4).
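In code, the new center is just an element-wise average. A tiny self-contained sketch (plain Scala, no Spark required):

// Element-wise mean of a cluster's members gives the new center.
val cluster = Seq(Array(6.0, 4.0, 8.0), Array(8.0, 2.0, 2.0), Array(4.0, 6.0, 2.0))
val dims = cluster.head.indices
val center = dims.map(d => cluster.map(_(d)).sum / cluster.size)
println(center.mkString("(", ",", ")"))  // prints (6.0,4.0,4.0)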
k-means uses distance to describe the similarity between two data objects. Candidate distance functions include the Minkowski, Euclidean, Mahalanobis, and Canberra (Lance) distances; the Euclidean distance is used most often.
The algorithm terminates when the criterion function reaches its optimum or the maximum number of iterations is hit. With the Euclidean distance, the criterion function is usually the minimized sum of squared distances from each data object to its cluster center:

$$E = \sum_{i=1}^{k} \sum_{x \in D_i} \mathrm{dist}(C_i, x)^2$$

where k is the number of clusters, $D_i$ is the set of objects in the i-th cluster, $C_i$ is the center of the i-th cluster, and $\mathrm{dist}(C_i, x)$ is the distance from x to $C_i$.
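Written out in plain Scala (a sketch, independent of any Spark API), the criterion is:

// Squared Euclidean distance between two points.
def squaredDist(a: Array[Double], b: Array[Double]): Double =
  a.zip(b).map { case (x, y) => (x - y) * (x - y) }.sum

// Criterion E: sum over all points of the squared distance to the nearest center.
def criterion(points: Seq[Array[Double]], centers: Seq[Array[Double]]): Double =
  points.map(p => centers.map(c => squaredDist(c, p)).min).sum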
How it works
1. Randomly select k objects from dataset D as the initial cluster centers.
2. Using the current centers, assign each of the n objects in the dataset to the most "similar" cluster (similarity judged by distance).
3. Recompute each cluster's center from its current members.
4. Compute the criterion function.
5. If the criterion function meets the threshold, stop; otherwise return to step 2. (A minimal sketch of this loop follows.)
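To make the loop concrete, here is a minimal plain-Scala sketch of steps 1-5 (illustrative only; in practice you would use the Spark implementations described below):

import scala.util.Random

// points: the dataset; k: number of clusters; eps: threshold on total center movement
def kMeans(points: Seq[Array[Double]], k: Int, eps: Double = 1e-4): Seq[Array[Double]] = {
  def sqDist(a: Array[Double], b: Array[Double]) =
    a.zip(b).map { case (x, y) => (x - y) * (x - y) }.sum
  // Step 1: randomly pick k objects as the initial centers
  var centers = Random.shuffle(points).take(k)
  var moved = Double.MaxValue
  while (moved > eps) {                     // Step 5: loop until the criterion is met
    // Step 2: assign each object to its nearest center
    val clusters = points.groupBy(p => centers.minBy(c => sqDist(c, p)))
    // Step 3: recompute each center as the mean of its members
    val newCenters = centers.map { c =>
      val members = clusters.getOrElse(c, Seq(c)) // keep the old center if a cluster is empty
      members.transpose.map(col => col.sum / members.size).toArray
    }
    // Step 4: here the criterion is the total squared movement of the centers
    moved = centers.zip(newCenters).map { case (a, b) => sqDist(a, b) }.sum
    centers = newCenters
  }
  centers
}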
3. A Simple Worked Example
Data: player information (per month)
Player (ID) | Play time (hours) | Top-up amount (yuan) |
---|---|---|
1 | 60 | 55 |
2 | 90 | 86 |
3 | 30 | 22 |
4 | 15 | 11 |
5 | 288 | 300 |
6 | 223 | 200 |
7 | 0 | 0 |
8 | 14 | 5 |
9 | 320 | 280 |
10 | 65 | 55 |
11 | 13 | 0 |
12 | 10 | 18 |
13 | 115 | 108 |
14 | 3 | 0 |
15 | 52 | 40 |
16 | 62 | 76 |
17 | 73 | 80 |
18 | 45 | 30 |
19 | 1 | 0 |
20 | 180 | 166 |
The data is abstracted as follows; each record means: play time (hours), top-up amount (yuan).
The players are divided into 3 classes:
1. Premium users (long play time, high spending)
2. Ordinary players (medium play time, medium spending)
3. Inactive users (short play time, low spending)
Flowchart (not reproduced here)
Test code
import org.apache.spark.mllib.clustering.KMeans
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.{SparkConf, SparkContext}

object KMeansTest {
  def main(args: Array[String]) {
    val conf = new SparkConf()
    val sc = new SparkContext(conf)
    // Load the data file passed as the first argument;
    // each line is "playTime topUpAmount", separated by a space.
    val data = sc.textFile(args(0))
    val parsedData = data.map(s => Vectors.dense(s.split(' ').map(_.trim.toDouble))).cache()
    // Number of clusters
    val numClusters = 3
    // At most 20 iterations
    val numIterations = 20
    // Run 10 times and keep the best solution
    val runs = 10
    val clusters = KMeans.train(parsedData, numClusters, numIterations, runs)
    // Evaluate clustering by computing Within Set Sum of Squared Errors
    val WSSSE = clusters.computeCost(parsedData)
    println("Within Set Sum of Squared Errors = " + WSSSE)
    val a21 = clusters.predict(Vectors.dense(57.0, 30.0))
    val a22 = clusters.predict(Vectors.dense(0.0, 0.0))
    // Print the cluster centers
    println("Cluster centers:")
    for (center <- clusters.clusterCenters) {
      println(" " + center)
    }
    // Print which cluster each test record belongs to
    println(parsedData.map(v => v.toString() + " belongs to cluster: " + clusters.predict(v))
      .collect().mkString("\n"))
    println("Predicted class for user 21 --> " + a21)
    println("Predicted class for user 22 --> " + a22)
  }
}
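Note that the SparkConf above sets neither an application name nor a master, so it relies on them being supplied at submit time. For a quick local test you could construct it as follows (my addition, not part of the original code):

// Local testing only; when submitting to a cluster, leave these to spark-submit.
val conf = new SparkConf().setAppName("KMeansTest").setMaster("local[*]")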
Explanation of the output
It is easy to see that:
class-1 users are the premium users
class-2 users are the ordinary players
class-3 users are the inactive users
User 21's data is (57, 30)
User 22's data is (0, 0)
Both predictions are correct.
The program also prints the centers of the three clusters.
4. Strengths and Weaknesses
The k-means clustering algorithm is a classic: simple, efficient, and easy to understand and implement. Its time complexity is low, O(tkmn), where t is the number of iterations, k the number of clusters, m the number of records, and n the number of dimensions, and typically t << m and k << m.
k-means also has a number of shortcomings.
- The number of clusters must be fixed in advance by hand, and choosing k is often a difficult problem.
- It is sensitive to initialization: the result depends on the choice of initial centers. (See the sketch after this list.)
- It is very sensitive to noise and outliers. An outlier with very large values can severely distort the result.
- It cannot handle clustering of non-convex data distributions.
- It mainly discovers circular or spherical clusters and cannot identify non-spherical ones.
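Some of these weaknesses can be mitigated in practice. The sensitivity to initial centers, for example, is why MLlib defaults to the k-means|| initialization mode described in the next section. A sketch of requesting it explicitly (using the MLlib API covered below):

// Use k-means|| initialization (MLlib's default) rather than random centers,
// which reduces the initialization sensitivity noted above.
val model = new KMeans()
  .setK(3)
  .setMaxIterations(20)
  .setInitializationMode(KMeans.K_MEANS_PARALLEL)
  .run(parsedData)  // parsedData: RDD[Vector], as in the test code above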
II. K-means in Spark MLlib
1. Parameters and Constructor
The KMeans implementation class of k-means in Spark MLlib has the following parameters.
class KMeans private (
    private var k: Int,
    private var maxIterations: Int,
    private var runs: Int,
    private var initializationMode: String,
    private var initializationSteps: Int,
    private var epsilon: Double,
    private var seed: Long) extends Serializable with Logging
The meaning of each parameter:
Name | Description |
---|---|
k | The desired number of clusters. |
maxIterations | The maximum number of iterations for a single run. |
runs | The number of times the algorithm is run. k-means is not guaranteed to return the globally optimal clustering, so running it several times on the target dataset helps return the best result. |
initializationMode | How the initial cluster centers are chosen; currently either random selection or K_MEANS_PARALLEL, with K_MEANS_PARALLEL the default. |
initializationSteps | The number of steps in the K_MEANS_PARALLEL method. |
epsilon | The convergence threshold for the k-means iterations. |
seed | The random seed for initializing the clusters. |
The MLlib k-means constructor
The defaults used when constructing an MLlib k-means instance are as follows.
{k: 2, maxIterations: 20, runs: 1, initializationMode: KMeans.K_MEANS_PARALLEL, initializationSteps: 5, epsilon: 1e-4, seed: random}
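Equivalently, a non-default instance is configured through setters named after these parameters (a sketch; the values shown are arbitrary):

// Start from the defaults above and override selectively.
val kmeans = new KMeans()
  .setK(5)               // default 2
  .setMaxIterations(50)  // default 20
  .setEpsilon(1e-4)      // convergence threshold
  .setSeed(42L)          // fix the seed for reproducible runs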
2. The MLlib k-means Training Function
The MLlib training function KMeans.train has many overloads; the one with the fullest set of parameters is described here. KMeans.train is as follows.
def train(
    data: RDD[Vector],
    k: Int,
    maxIterations: Int,
    runs: Int,
    initializationMode: String,
    seed: Long): KMeansModel = {
  new KMeans().setK(k)
    .setMaxIterations(maxIterations)
    .setRuns(runs)
    .setInitializationMode(initializationMode)
    .setSeed(seed)
    .run(data)
}
The parameters have the same meanings as in the constructor and are not repeated here.
3. The MLlib k-means Prediction Function
The MLlib prediction function KMeansModel.predict accepts inputs in different formats, either a vector or an RDD, and returns the index of the cluster the input belongs to. The KMeansModel.predict API is as follows.
def predict(point: Vector): Int
def predict(points: RDD[Vector]): RDD[Int]
The first form accepts a single point and returns the index of its cluster; the second accepts a group of points and returns each point's cluster index as an RDD.
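For example, with the clusters model from the earlier test code (a usage sketch):

// One point: the cluster index is returned directly.
val idx: Int = clusters.predict(Vectors.dense(57.0, 30.0))
// Many points: one cluster index per input vector, returned as an RDD[Int].
val idxs = clusters.predict(parsedData)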
III. A K-means Example with Spark ML
Tables in the MySQL database:
user table: customs
goods table: goods
orders table: orders
order-items table: orderItems
Goal: group the existing users and recall (recommend) products per group.
1. The Database Connection Method
A method that reads a table from the database (the table name is passed in as a parameter):
def readMysql(spark:SparkSession,tableName:String) = {
val map:Map[String,String] = Map[String,String](
elems = "url"->"jdbc:mysql://192.168.109.138:3306/myshops",
"driver"->"com.mysql.jdbc.Driver",
"user"->"root",
"password"->"root",
"dbtable"->tableName
)
spark.read.format("jdbc").options(map).load()
}
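A usage note: for large tables, Spark's JDBC source also accepts partitioning options (standard Spark JDBC options, not used in the original) so the read is not funneled through a single connection:

// Optional: spread the JDBC read over 4 partitions keyed on a numeric column.
val opts = Map(
  "url" -> "jdbc:mysql://192.168.109.138:3306/myshops",
  "driver" -> "com.mysql.jdbc.Driver",
  "user" -> "root",
  "password" -> "root",
  "dbtable" -> "orders",
  "partitionColumn" -> "ord_id",
  "lowerBound" -> "1",
  "upperBound" -> "100000",   // assumed id range; tune to the real table
  "numPartitions" -> "4"
)
val orders = spark.read.format("jdbc").options(opts).load()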
2. Custom Functions
Some user-defined functions needed by the processing below.
// Bucket membership points into levels 1-4
val func_membership = udf{
(score:Int)=> {
score match {
case i if i<100 => 1
case i if i<500 => 2
case i if i<1000 => 3
case _ => 4
}
}
}
// Compute age in years from the birth date embedded in the ID-card number
val func_bir = udf {
(idno: String, now: String) => {
var year = idno.substring(6, 10).toInt
var month = idno.substring(10, 12).toInt
val day = idno.substring(12, 14).toInt
val dts = now.split("-")
val nowYear = dts(0).toInt
val nowMonth = dts(1).toInt
val nowDay = dts(2).toInt
if (nowMonth > month) {
nowYear - year
} else if (nowMonth < month) {
nowYear - 1 - year
} else {
if (nowDay >= day) {
nowYear - year
} else {
nowYear - 1 - year
}
}
}
}
// Bucket age into 7 bands
val func_claAge = udf{
(bir:Int)=>{
bir match {
case i if i<10 => 1
case i if i<18 => 2
case i if i<23 => 3
case i if i<35 => 4
case i if i<50 => 5
case i if i<70 => 6
case _ => 7
}
}
}
// Bucket loyalty points into 3 user-score levels
val func_userScore = udf{
(score:Int)=>{
score match {
case i if i<100 => 1
case i if i<500 => 2
case _ => 3
}
}
}
// Bucket login counts into 2 levels
val func_logcount = udf{
(score:Int)=>{
score match {
case i if i<500 => 1
case _ => 2
}
}
}
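These UDFs are easy to sanity-check on a tiny DataFrame before wiring them into the pipeline (a sketch assuming the same spark session):

import spark.implicits._
// Membership levels should land in buckets 1-4.
val probe = Seq(50, 100, 499, 500, 1200).toDF("membership_level")
probe.withColumn("mslevel", func_membership($"membership_level")).show()
// expected: 50 -> 1, 100 -> 2, 499 -> 2, 500 -> 3, 1200 -> 4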
3. Data Cleaning
Analyze the data in the tables, identify the columns relevant to grouped recall, and build the working table:
def main(args: Array[String]): Unit = {
  val spark = SparkSession.builder().appName("db").master("local[*]").getOrCreate()
  //user table
  /**
   * Columns needed from the user table and the related cleaning:
   * user id          cust_id
   * company          company (string converted to a number)
   * province         province_id
   * city             city_id
   * district         district_id
   * membership level membership_level (bucketed: 0-100 -> 1, 100-500 -> 2, 500-1000 -> 3, 1000+ -> 4)
   * creation date    create_at (max-min = days)
   * last login date  last_login_time (max-min = days)
   * account status   active (used as a filter) 0 closed, 1 active
   * ID-card number   idno --> compute the age first, then bucket it (0-10 -> 1, 11-18 -> 2, 19-23 -> 3, 24-35 -> 4, 36-50 -> 5, 51-70 -> 6, 71+ -> 7)
   * loyalty points   biz_point (0-100 -> 1, 100-500 -> 2, 500+ -> 3)
   * sex              sex 0/1 (when missing, derived from the ID-card number)
   * marital status   marital_status 0 unmarried, 1 married
   * education        education_id
   * login count      login_count (0-500 -> 1, 500+ -> 2)
   * industry         vocation
   * position         post
   */
val userTable = readMysql(spark,"customs")
//过滤出存活的用户
.filter("active!=0")
.select("cust_id","company","province_id","city_id","district_id"
,"membership_level","create_at","last_login_time","active","idno","biz_point"
,"sex","marital_status","education_id","login_count","vocation","post")
//商品表
val goodTable = readMysql(spark,"goods").select("good_id","price")
//订单表
val orderTable = readMysql(spark,"orders").select("ord_id","cust_id")
//订单明细表
val ordertailTable = readMysql(spark,"orderItems").select("ord_id","good_id","buy_num")
4. Business Logic
The related processing:
//convert company names to numbers with a StringIndexer
val compIndex = new StringIndexer().setInputCol("company").setOutputCol("compId")
//import the implicits
import spark.implicits._
//number of purchases per user
val tmp_bc = orderTable.groupBy("cust_id").agg(count($"ord_id").as("buycount"))
//total amount each user has spent on the site
val tmp_pay = orderTable.join(ordertailTable,Seq("ord_id"),"inner")
  .join(goodTable,Seq("good_id"),"inner")
  .groupBy("cust_id")
  .agg(sum($"buy_num"*$"price").as("pay"))
val df = compIndex.fit(userTable).transform(userTable)
  //membership-point level
  .withColumn("mslevel", func_membership($"membership_level"))
  //turn the string timestamps into day offsets from the earliest value
  .withColumn("reg_date", datediff($"create_at", min($"create_at").over()))
  .withColumn("last_login_date", datediff($"last_login_time", min($"last_login_time").over()))
  //compute the age band from the ID-card number
  //(current_date() is formatted to a string to match func_bir's (String, String) signature)
  .withColumn("bir", func_claAge(func_bir($"idno", date_format(current_date(), "yyyy-MM-dd"))))
  //user score level
  .withColumn("user_score", func_userScore($"biz_point"))
  //login-count level
  .withColumn("logcount", func_logcount($"login_count"))
  //join in the two aggregates
  .join(tmp_bc, Seq("cust_id"), "left")
  .join(tmp_pay, Seq("cust_id"), "left")
  //replace nulls with 0
  .na.fill(0)
  //drop the cleaned, no-longer-meaningful columns
  .drop("company", "membership_level"
    , "create_at", "idno", "biz_point", "login_count","last_login_time")
//cast every column to double,
//which is convenient for plotting and for k-means later
val columns = df.columns.map(f => col(f).cast(DoubleType))
val num_fmt = df.select(columns: _*)
//assemble the feature columns into a single vector column
val va = new VectorAssembler().setInputCols(
  Array("cust_id","province_id","city_id","district_id"
    ,"active","sex","marital_status","education_id"
    ,"vocation","post","compId","mslevel","reg_date"
    ,"last_login_date","bir","user_score","logcount"
    ,"buycount","pay")).setOutputCol("orign_feature")
val ofdf = va.transform(num_fmt).select("cust_id", "orign_feature")
//min-max normalize the raw features, which helps convergence
val mmScaler = new MinMaxScaler().setInputCol("orign_feature").setOutputCol("feature")
val resdf = mmScaler.fit(ofdf).transform(ofdf)
  .select("cust_id","feature").cache()
//to display the result:
// val resdf = mmScaler.fit(ofdf).transform(ofdf).select("cust_id","feature").show(false)
+-------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|cust_id|feature |
+-------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|496.0 |[0.004920393631490519,0.48484848484848486,0.4749262536873156,0.4754339511621065,0.5,0.0,0.0,1.0,0.9859719438877755,0.9851940776310524,0.6229854955680902,0.3333333333333333,0.5067753077672482,0.3793429003021148,0.5,1.0,0.0,0.14285714285714285,0.08377923480794836] |
|833.0 |[0.008290663253060245,0.8787878787878788,0.8643067846607669,0.8634892615475139,0.5,0.0,0.0,0.25,0.905811623246493,0.9051620648259304,0.5506647864625303,0.3333333333333333,0.5067753077672482,0.7931457703927492,0.0,1.0,0.0,0.2857142857142857,0.08109126535533952] |
|1088.0 |[0.01084086726938155,0.6363636363636364,0.6371681415929203,0.6363636363636364,0.5,0.0,0.0,0.75,0.8877755511022044,0.8879551820728291,0.21555197421434327,0.6666666666666666,0.48516517580373747,0.7585913897280967,0.5,0.0,0.0,0.0,0.0] |
|1580.0 |[0.01576126090087207,0.5151515151515151,0.5191740412979351,0.5198587819947044,0.5,0.0,0.0,0.5,0.657314629258517,0.657062825130052,0.4879129734085415,0.6666666666666666,0.42848286245682404,0.8620657099697885,0.5,1.0,0.0,0.5714285714285714,0.49068421801860646] |
|1591.0 |[0.015871269701576127,0.21212121212121213,0.20648967551622419,0.20623712856722565,0.5,1.0,0.0,0.75,0.969939879759519,0.9699879951980792,0.7966357775987107,0.0,0.20467629085112035,0.7585913897280967,0.0,1.0,0.0,0.2857142857142857,0.12288458478827007] |
|1829.0 |[0.018251460116809344,0.24242424242424243,0.25663716814159293,0.25831126802000587,0.5,0.0,1.0,0.0,1.0,0.9987995198079231,0.2694399677679291,0.6666666666666666,0.7250022141528651,0.20685422960725075,0.0,1.0,0.0,0.0,0.0] |
|2366.0 |[0.023621889751180094,0.21212121212121213,0.2153392330383481,0.21535745807590467,0.5,0.0,0.0,0.75,0.6372745490981964,0.636654661864746,0.4395648670427075,0.6666666666666666,0.6711540164732973,0.06891993957703928,0.0,0.0,0.0,0.0,0.0] |
|3749.0 |[0.03745299623969917,0.36363636363636365,0.35398230088495575,0.353927625772286,0.5,1.0,1.0,1.0,0.3066132264529058,0.30612244897959184,0.011583400483481063,0.3333333333333333,0.956868302187583,0.20685422960725075,0.5,0.0,0.0,0.14285714285714285,0.16570787382453672] |
|3794.0 |[0.037903032242579404,0.5757575757575758,0.5752212389380531,0.5760517799352751,0.5,1.0,1.0,1.0,0.7755511022044088,0.7759103641456583,0.01863416599516519,0.6666666666666666,0.36923213178637854,0.7585913897280967,0.5,0.0,0.0,0.0,0.0] |
|4101.0 |[0.04097327786222898,0.9393939393939394,0.9174041297935103,0.9173286260664901,0.5,0.0,0.0,1.0,0.36472945891783565,0.3641456582633053,0.4894238517324738,0.3333333333333333,0.8840669559826411,0.7585913897280967,0.5,1.0,0.0,0.0,0.0] |
|4900.0 |[0.04896391711336907,0.5151515151515151,0.504424778761062,0.5045601647543395,0.5,0.0,1.0,0.75,0.280561122244489,0.2801120448179272,0.11714343271555197,0.0,0.5551324063413338,0.5861971299093656,0.0,1.0,0.0,0.0,0.0] |
|4935.0 |[0.04931394511560925,0.45454545454545453,0.4631268436578171,0.46337157987643424,0.5,0.0,0.0,0.25,0.04408817635270541,0.044017607042817125,0.11956083803384368,0.3333333333333333,0.8247276591975911,0.7241314199395771,0.5,0.0,0.0,0.2857142857142857,0.1707153683188049]|
|5300.0 |[0.05296423713897112,0.12121212121212122,0.12979351032448377,0.13033245072080024,0.5,1.0,0.0,0.0,0.14428857715430862,0.1452581032412965,0.28031829170024175,0.6666666666666666,0.3746346647772562,0.6896714501510574,0.5,0.0,0.0,0.0,0.0] |
|5518.0 |[0.05514441155292423,0.7575757575757576,0.7374631268436578,0.736098852603707,0.5,1.0,1.0,0.75,0.9458917835671342,0.9443777511004402,0.4545729250604351,0.3333333333333333,0.7331502966964839,0.8620657099697885,0.5,1.0,0.0,0.0,0.0] |
|5803.0 |[0.05799463957116569,0.9696969696969697,0.9646017699115044,0.9620476610767873,0.5,1.0,0.0,0.75,0.4849699398797595,0.4841936774709884,0.4734085414987913,0.6666666666666666,0.7088831812948366,0.44826283987915405,0.5,0.0,0.0,0.14285714285714285,0.0057946257358421046] |
|6466.0 |[0.06462517001360109,0.7272727272727273,0.7315634218289085,0.7302147690497205,0.5,0.0,1.0,0.75,0.3486973947895792,0.34973989595838334,0.15783642224012892,0.0,0.382782747320875,0.6552114803625377,0.5,0.0,0.0,0.2857142857142857,0.2529119669069929] |
|6620.0 |[0.06616529322345788,0.2727272727272727,0.2890855457227139,0.2900853192115328,0.5,0.0,0.0,0.0,0.8416833667334669,0.8419367747098839,0.20114826752618856,0.6666666666666666,0.471614560269241,0.20685422960725075,0.0,0.0,0.0,0.0,0.0] |
|6654.0 |[0.06650532042563405,0.030303030303030304,0.0471976401179941,0.04766107678729038,0.5,1.0,1.0,0.25,0.9258517034068137,0.9255702280912365,0.3037872683319903,0.3333333333333333,0.7195996811619875,0.3793429003021148,0.5,1.0,0.0,0.14285714285714285,0.0] |
|6658.0 |[0.06654532362589007,0.3939393939393939,0.3893805309734513,0.3901147396293027,0.5,1.0,0.0,0.0,0.7314629258517034,0.7306922769107643,0.9811643835616438,0.6666666666666666,0.4824196262509964,0.7585913897280967,0.0,1.0,0.0,0.2857142857142857,0.1025364048199228] |
|7240.0 |[0.07236578926314105,0.6060606060606061,0.6135693215339233,0.6122388937922919,0.5,0.0,1.0,0.0,0.5050100200400801,0.5046018407362945,0.33340048348106366,0.6666666666666666,0.4500929944203348,0.8620657099697885,0.5,1.0,0.0,0.0,0.0] |
+-------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
only showing top 20 rows
5. Plotting to Choose the Number of Centroids
This uses JFreeChart's ChartPanel for drawing.
A custom chart-window class:
import java.awt.Font
import scala.collection.mutable.ListBuffer
import org.jfree.chart.{ChartFactory, ChartPanel}
import org.jfree.chart.plot.PlotOrientation
import org.jfree.data.category.DefaultCategoryDataset
import org.jfree.ui.ApplicationFrame

/*
A window that draws the cost curve
*/
class LineGraph(appName: String) extends ApplicationFrame(appName) {
  def this(appName: String, chartTitle: String, list: ListBuffer[Double]) {
    this(appName)
    val lineChart = ChartFactory.createLineChart(chartTitle,
      "centroids", "distance", createDataset(list),
      PlotOrientation.VERTICAL, true, true, false)
    //window fonts
    val font = new Font("宋体", Font.BOLD, 20)
    lineChart.getTitle.setFont(font)                           //title font
    lineChart.getLegend.setItemFont(font)                      //legend font
    lineChart.getCategoryPlot.getDomainAxis.setLabelFont(font) //x-axis font
    lineChart.getCategoryPlot.getRangeAxis.setLabelFont(font)  //y-axis font
    val chartPanel = new ChartPanel(lineChart)
    chartPanel.setPreferredSize(new java.awt.Dimension(1600, 1200))
    setContentPane(chartPanel)
  }
  def createDataset(list: ListBuffer[Double]): DefaultCategoryDataset = {
    val dataset = new DefaultCategoryDataset()
    //the x axis starts at k = 2, matching the loop below
    var point = 2
    for (dst <- list) {
      dataset.addValue(dst, "dist", point + "")
      point += 1
    }
    dataset
  }
}
Choosing the number of centroids
Run k-means in a loop to get the cost for each candidate number of centroids, plot the curve, and read a suitable centroid count off the chart. (Once the value is found, this step is no longer needed and can be commented out.)
This is very memory-hungry; my poor laptop barely survived, and it took a long time to finish.
//compute the total cost for each candidate number of centroids
//collect the costs here, one per candidate
val disLst: ListBuffer[Double] = ListBuffer[Double]()
for (i <- 2 to 35) {
  val kms = new KMeans()
    .setFeaturesCol("feature")
    .setK(i)
  val model = kms.fit(resdf)
  disLst.append(model.computeCost(resdf))
}
//draw the chart
val chart = new LineGraph("app", "K-means centroid cost", disLst)
chart.pack()
RefineryUtilities.centerFrameOnScreen(chart)
chart.setVisible(true)
From the chart, the curve flattens out at roughly 40, which is the value used below. (To see it more clearly, you can raise the 35 in the loop, provided your machine can take it.)
6. Grouped Recall
//group the users with k-means
val kms = new KMeans()
  .setFeaturesCol("feature")
  .setK(40)
val user_group_tab = kms.fit(resdf).transform(resdf).drop("feature")
  .withColumnRenamed("prediction","groups")
//take the top 30 best-selling products in each group
val rank = 30
val wnd = Window.partitionBy("groups").orderBy(desc("group_buy_count"))
user_group_tab
  .join(orderTable,Seq("cust_id"),"left")
  .join(ordertailTable,Seq("ord_id"),"left")
  .na.fill(0)
  .groupBy("groups","good_id")
  .agg(count("ord_id").as("group_buy_count"))
  .withColumn("rank",row_number().over(wnd))
  .filter($"rank"<=rank)
  .show(false)
spark.close()
}
//the good_id 0 row is an artifact of filling nulls with 0
+------+-------+---------------+----+
|groups|good_id|group_buy_count|rank|
+------+-------+---------------+----+
|31 |0 |694 |1 |
|31 |2974 |3 |2 |
|31 |28528 |3 |3 |
|31 |21243 |3 |4 |
|31 |21610 |3 |5 |
|31 |2246 |3 |6 |
|31 |8016 |3 |7 |
|31 |14912 |3 |8 |
|31 |16435 |3 |9 |
|31 |15391 |3 |10 |
|31 |21871 |3 |11 |
|31 |23981 |3 |12 |
|31 |21777 |3 |13 |
|31 |9289 |3 |14 |
|31 |8839 |3 |15 |
|31 |6205 |3 |16 |
|31 |27560 |2 |17 |
|31 |29696 |2 |18 |
|31 |20833 |2 |19 |
|31 |15421 |2 |20 |
+------+-------+---------------+----+
only showing top 20 rows