1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73
|
def kmeans(x, k, max_it=32):
    r"""
    Cluster sentence lengths into at most ``k`` buckets with 1-D KMeans.

    Args:
        x (list[int]): The list of sentence lengths.
        k (int): Upper bound on the number of clusters; the final count can
            be smaller, since clusters left empty at the end are dropped.
        max_it (int): Iteration cap. The loop stops early once the
            centroids converge.

    Returns:
        list[float], list[list[int]]: Average sentence length of each kept
        cluster, and the clusters themselves as lists of datapoint indices.

    Examples:
        >>> x = torch.randint(10,20,(10,)).tolist()
        >>> x
        [15, 10, 17, 11, 18, 13, 17, 19, 18, 14]
        >>> centroids, clusters = kmeans(x, 3)
        >>> centroids
        [10.5, 14.0, 17.799999237060547]
        >>> clusters
        [[1, 3], [0, 5, 9], [2, 4, 6, 7, 8]]
    """
    # Work in float; never ask for more clusters than there are datapoints.
    x, k = torch.tensor(x, dtype=torch.float), min(len(x), k)

    def members():
        # Membership matrix: row i is True at the points assigned to cluster i.
        return torch.arange(k).unsqueeze(-1).eq(y)

    def vacant(mask):
        # Indices of clusters that currently hold no datapoint at all.
        return torch.where(~mask.any(-1))[0].tolist()

    # Seed the centroids with (up to) k distinct length values picked at random.
    uniq = x.unique()
    c = uniq[torch.randperm(len(uniq))[:k]]
    # Initial assignment: each point joins its nearest centroid.
    dists, y = torch.abs_(x.unsqueeze(-1) - c).min(-1)

    for _ in range(max_it):
        mask = members()
        empty = vacant(mask)
        # Repair pass: refill each empty cluster by stealing the point that
        # lies farthest from the centroid of the currently biggest cluster.
        while empty:
            for i in empty:
                biggest = torch.where(mask[mask.sum(-1).argmax()])[0]
                farthest = dists[biggest].argmax()
                y[biggest[farthest]] = i
                mask = members()
            empty = vacant(mask)
        # New centroids are the per-cluster means; remember the old ones.
        old, c = c, (x * mask).sum(-1) / mask.sum(-1)
        # Re-assign every point to its (possibly moved) nearest centroid.
        dists, y = torch.abs_(x.unsqueeze(-1) - c).min(-1)
        # Converged: the centroids did not move this round.
        if c.equal(old):
            break
    # Clusters that ended up unused are silently discarded.
    kept = y.unique().tolist()
    centroids = c[kept].tolist()
    # Map each kept cluster id to the indices of its datapoints.
    clusters = [torch.where(y.eq(i))[0].tolist() for i in kept]

    return centroids, clusters
|