import sys #The sys module contains functions related to the Python interpreter and its environment.
from typing import List #typing provides the List type hint used in the annotations below
import numpy as np #This code requires NumPy
from pyspark.sql import SparkSession #Provides users with a unified entry point to use Spark's features
def parseVector(line: str) -> np.ndarray: #Parse one text line into a point vector (np.ndarray)
    return np.array([float(x) for x in line.split(' ')]) #Split the line on spaces and convert each token to float, returning the point as an np.ndarray
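# Illustration (hypothetical input line): parseVector("0.0 0.0 0.0") would return
# np.array([0.0, 0.0, 0.0]); the input file is expected to contain one space-separated
# point per line.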
def closestPoint(p: np.ndarray, centers: List[np.ndarray]) -> int: #Calculates the distance from a point to the center of each cluster, returning the index of the nearest cluster
    bestIndex = 0 #Index of the closest cluster center found so far
    closest = float("+inf") #Smallest squared distance found so far, initialized to +infinity
    for i in range(len(centers)): #Compare the point against every cluster center
        tempDist = np.sum((p - centers[i]) ** 2) #Squared Euclidean distance from p to center i
        if tempDist < closest: #Remember this center if it is closer than the best seen so far
            closest = tempDist
            bestIndex = i
    return bestIndex #Return the index of the nearest cluster center
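# Illustration (made-up values): with p = np.array([1.0, 1.0]) and
# centers = [np.array([0.0, 0.0]), np.array([5.0, 5.0])], the squared distances are
# 2.0 and 32.0, so closestPoint returns 0.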
if __name__ == "__main__": #Entry point: runs only when the script is executed directly
    if len(sys.argv) != 4: #Expect the script name plus exactly three command-line arguments
        print("Usage: kmeans <file> <k> <convergeDist>", file=sys.stderr) #Print the usage message to standard error
        sys.exit(-1) #Exit with a non-zero status to signal failure
print("""WARN: This is a naive implementation of KMeans Clustering and is given
as an example! Please refer to examples/src/main/python/ml/kmeans_example.py for an
example on how to use ML's KMeans implementation.""", file=sys.stderr) #Print the warning to standard error
    #Build a SparkSession named "PythonKMeans" (the unified entry point to Spark's features)
    spark = SparkSession\
        .builder\
        .appName("PythonKMeans")\
        .getOrCreate() #Reuse an existing SparkSession if one is running, otherwise create a new one
    lines = spark.read.text(sys.argv[1]).rdd.map(lambda r: r[0]) #Read the input file (first argument) as a DataFrame of Rows, convert it to an RDD, and keep only the text of each Row (r[0])
    data = lines.map(parseVector).cache() #Parse each line into a point vector and cache the RDD, since it is reused in every iteration
    K = int(sys.argv[2]) #K, the number of clusters, is the second command-line argument
    convergeDist = float(sys.argv[3]) #The convergence threshold (third argument): iteration stops once the total movement of the centers falls to or below this value
    #Pick K initial cluster centers by sampling K points; False means sampling without replacement, and 1 is the random seed
    kPoints = data.takeSample(False, K, 1)
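    # Note: takeSample returns a plain Python list of K ndarrays (not an RDD). For K = 2 and
    # 3-dimensional data it might look like [array([0.1, 0.1, 0.1]), array([9.0, 9.1, 9.2])]
    # (values are illustrative only); the fixed seed makes the initial centers reproducible.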
    tempDist = 1.0 #Total movement of the centers in the last iteration; initialized to 1.0 so the loop body runs at least once
    while tempDist > convergeDist: #Iterate until the total center movement (tempDist) is no greater than convergeDist
        closest = data.map( #Assign each point to its nearest center: the resulting RDD contains
            lambda p: (closestPoint(p, kPoints), (p, 1))) #(closestPoint(p, kPoints), (p, 1)) tuples; the 1 in (p, 1) makes it easy to count the points per cluster later
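        # Illustrative element of `closest`: (2, (array([9.0, 9.1, 9.2]), 1)),
        # i.e. (index of the nearest center, (the point itself, a count of 1)).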
        pointStats = closest.reduceByKey( #For each cluster, sum the point vectors component-wise and add up the counts
            lambda p1_c1, p2_c2: (p1_c1[0] + p2_c2[0], p1_c1[1] + p2_c2[1])) #p1_c1[1] and p2_c2[1] are the running counts; adding them counts the number of points in the cluster
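        # Illustrative element of `pointStats`: (2, (array([27.0, 27.3, 27.6]), 3)),
        # i.e. (cluster index, (component-wise sum of its points, number of points in the cluster)).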
        newPoints = pointStats.map( #Compute the new center of each cluster by dividing the summed vector by the number of points
            lambda st: (st[0], st[1][0] / st[1][1])).collect() #st[0] is the cluster index, st[1][0] is the component-wise sum of its points, and st[1][1] is the number of points
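        # `newPoints` is a collected Python list such as
        # [(0, array([0.1, 0.2, 0.1])), (2, array([9.0, 9.1, 9.2]))] (illustrative),
        # pairing each cluster index with its recomputed center.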
        #Compute the total squared Euclidean distance between each new cluster center and the corresponding old one
tempDist = sum(np.sum((kPoints[iK] - p) ** 2) for (iK, p) in newPoints)
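        # Equivalently, tempDist = sum over clusters i of ||newCenter_i - oldCenter_i||^2,
        # i.e. the total squared movement of all centers in this iteration.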
#Updates the center of each cluster with the coordinates of the newly obtained cluster center
for (iK, p) in newPoints:
kPoints[iK] = p
#Prints out the center point coordinates of the resulting cluster
print("Final centers: " + str(kPoints))
#Stop the spark session
spark.stop()
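# Example invocation (file path and argument values are illustrative; adjust to your own data):
#   spark-submit kmeans.py data/mllib/kmeans_data.txt 2 0.1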