去年的大数据算法作业: 在Spark环境下用Scala实现k-means(下面的代码只用到了Scala标准集合, 没有用到RDD).
k-means本身是一个很直观的算法, 可以看作乞丐版的GMM(高斯混合模型).
一句话解释: 先选k个初始点作为类中心; 然后其他小弟(样本)按就近原则拜山头; 拜完山头后, 大家在各自的山头里重新选出最能代表他们的大佬(中心点, 即该类所有样本的均值); 如此反复, 直到大佬不再变化为止.
/** Euclidean distance between two points.
  *
  * Coordinates are paired with `zip`, so when the vectors differ in length the extra
  * coordinates of the longer one are silently ignored. Two empty vectors are at
  * distance 0.0 (summing with `.sum` instead of `reduce` avoids an exception there).
  *
  * @param p first point
  * @param q second point
  * @return the square root of the sum of squared coordinate differences
  */
def distance(p: Vector[Double], q: Vector[Double]): Double =
  math.sqrt(p.zip(q).map { case (a, b) => (a - b) * (a - b) }.sum)
/** Returns the element of `candidates` nearest to `q` according to [[distance]].
  *
  * On ties the earliest candidate wins. An empty `candidates` array makes `minBy` throw.
  */
def clostestpoint(q: Vector[Double], candidates: Array[Vector[Double]]): Vector[Double] =
  candidates.minBy(candidate => distance(q, candidate))
/** Component-wise sum of two vectors.
  *
  * Pairing is done with `zip`, so the result has the length of the shorter input.
  */
def add_vec(v1: Vector[Double], v2: Vector[Double]): Vector[Double] =
  for ((left, right) <- v1 zip v2) yield left + right
/** Centroid (component-wise arithmetic mean) of a collection of points.
  *
  * The points are summed with `add_vec` and each coordinate of the sum is divided
  * by the number of points. An empty `cluster` makes `reduce` throw.
  */
def average(cluster: Iterable[Vector[Double]]): Vector[Double] = {
  val count = cluster.size
  val total = cluster.reduce((acc, point) => add_vec(acc, point))
  for (coordinateSum <- total) yield coordinateSum / count
}
/** Picks `k` distinct points from `data` at random to serve as initial centers.
  *
  * Positions in `data` are drawn uniformly, so a value that occurs several times is
  * proportionally more likely to be picked. Draws repeat until `k` distinct values
  * have been collected.
  *
  * @param k    number of centers to pick; `k <= 0` yields an empty array
  * @param data input points
  * @param rng  random source; pass a seeded `scala.util.Random` for reproducible runs
  * @return `k` distinct points taken from `data`
  * @throws IllegalArgumentException if `data` contains fewer than `k` distinct points
  *         (otherwise the draw loop could never finish)
  */
def randomInitPoints(k: Int, data: List[Vector[Double]],
                     rng: scala.util.Random = new scala.util.Random): Array[Vector[Double]] = {
  // Indexed copy: List.apply is O(n), which would make every draw linear in the data size.
  val pool = data.toVector
  val distinctCount = pool.distinct.size
  require(k <= distinctCount,
    s"cannot pick $k distinct initial centers from $distinctCount distinct points")
  val chosen = scala.collection.mutable.LinkedHashSet.empty[Vector[Double]]
  // Draw from this call's own data (not a global length) until k distinct points are found.
  while (chosen.size < k) chosen += pool(rng.nextInt(pool.size))
  chosen.toArray
}
/** Fixed initial centers (three 2-D points), a deterministic alternative to random
  * initialization. A new array is built on every call, so callers may mutate it freely.
  */
def obeservedInitPoints(): Array[Vector[Double]] =
  Array(
    Vector(3.1, 14.1),
    Vector(4.1, 1.1),
    Vector(14.1, 22.1)
  )
import scala.io.Source;
// Load the points: one per line, coordinates separated by tabs. The file is closed once read,
// and blank lines (e.g. a trailing newline at the end of the file) are skipped because
// "".toDouble would throw a NumberFormatException.
val inputSource = Source.fromFile("/home/bitnami/kmeans_data.txt")
val lines = try inputSource.getLines().toList finally inputSource.close()
val data = lines.filter(_.trim.nonEmpty).map(l => l.split('\t').map(_.toDouble).toVector)

val k = 3
val iteraTimes = 1000 // upper bound on iterations; the loop stops earlier once centers settle
val arrayLength = data.size // number of input points
var meanArray = randomInitPoints(k, data)
// var meanArray = obeservedInitPoints()

/** One k-means (Lloyd) iteration: assign every point to the index of its nearest center,
  * then move each center to the mean of the points assigned to it.
  *
  * Assigning by index rather than by comparing center values keeps two identical centers
  * from both claiming the same points. A center that attracts no points stays where it is
  * instead of being handed to `average` as an empty cluster (which would throw).
  */
def lloydStep(centers: Array[Vector[Double]]): Array[Vector[Double]] = {
  val clusters = data.groupBy(p => centers.indices.minBy(j => distance(p, centers(j))))
  centers.indices.map { n =>
    clusters.get(n).map(members => average(members)).getOrElse(centers(n))
  }.toArray
}

/** Applies [[lloydStep]] at most `remaining` times, stopping as soon as an iteration leaves
  * every center unchanged (the step depends only on the centers, so nothing would move again).
  */
@scala.annotation.tailrec
def runKMeans(centers: Array[Vector[Double]], remaining: Int): Array[Vector[Double]] =
  if (remaining <= 0) centers
  else {
    val next = lloydStep(centers)
    if (next.sameElements(centers)) next else runKMeans(next, remaining - 1)
  }

meanArray = runKMeans(meanArray, iteraTimes)

println("========Final mean out put==========")
meanArray.foreach(println)
println("========Final mean out put==========")