Faiss之聚类源码解析

目录

1. 初始化

2. 实例化Flat索引

3. 训练

3.1 数据校验

3.2 初始化并迭代更新聚类中心


首先,从宏观的角度来一张Faiss聚类的流程,如下图:

整体的代码如下:

float kmeans_clustering (size_t d, size_t n, size_t k,
                         const float *x,
                         float *centroids)
{
    Clustering clus (d, k); // 初始化参数
    clus.verbose = d * n * k > (1L << 30);
    // display logs if > 1Gflop per iteration
    IndexFlatL2 index (d); // 实例化flat索引
    clus.train (n, x, index); // 训练,也就是聚类
    memcpy(centroids, clus.centroids.data(), sizeof(*centroids) * d * k);
    return clus.iteration_stats.back().obj; //返回各样本点到各自聚类中心的距离之和
}

  大致分为三大步骤:初始化,实例化量化索引,以及训练

1. 初始化

聚类参数较多,罗列出几个比较重要的参数

/** Class for the clustering parameters. Can be passed to the
 * constructor of the Clustering object.
 */
struct ClusteringParameters {
    int niter;          ///< clustering iterations
    int nredo;          ///< redo clustering this many times and keep best

    bool spherical;     ///< do we want normalized centroids?
    bool int_centroids; ///< round centroids coordinates to integer
    
    int min_points_per_centroid; ///< otherwise you get a warning
    int max_points_per_centroid;  ///< to limit size of dataset
};

niter:每一次聚类需要迭代的次数

nredo: 训练的时候,聚类的次数

spherical: 是否需要归一化

update_index:重复训练的时候是否需要更新索引

min_points_per_centroid: 每一簇最小样本数,低于这个数会warning,但是还是会继续

max_points_per_centroid: 每一簇最大样本数,超过这个数会采样,后文会分析。

聚类是ClusteringParameters的子类,多了两个需要指定的参数d和k,和一个存储聚类中心的vector(centroids):

struct Clustering: ClusteringParameters {
    typedef Index::idx_t idx_t;
    size_t d;              ///< dimension of the vectors
    size_t k;              ///< nb of centroids

    /** centroids (k * d)
     * if centroids are set on input to train, they will be used as initialization
     */
    std::vector<float> centroids;
}

d 为参与聚类的向量维度,k为聚类中心的个数, centroids为聚类中心向量构成的一维vector。

2. 实例化Flat索引

Flat索引是faiss中最简单的索引,其实可以看成待搜索数据组成的一个list。用该索引搜索最近的top n个向量时,采用堆的方式来搜索。而flat索引在聚类中的作用为:聚类过程中,需要找到离每个样本点最近的聚类中心,而一般聚类中心不会太多,使用最简单的flat索引就已足够。所以这里利用一个flat索引存储聚类中心,通过样本在flat索引中找寻top1 的点,即为该样本最近的聚类中心。

3. 训练

参数准备好了,接下来,进入本文的重点,到底如何聚类呢。

void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
                                const Index * codec, Index & index,
                                const float *weights) 

3.1 数据校验

首先,必须满足以下条件才能继续:

    1. 训练的数据集的数量不得低于聚类中心数

    2. Flat的维数必须与训练数据集的维数一致

然后,如果训练数据集的总量小于k*min_points_per_centroid, 则发起warning,不过聚类还会继续。如果总量大于k*max_points_per_centroid,则采样k*max_points_per_centroid个数据来替代原有的训练数据集。

if (nx > k * max_points_per_centroid) {
        uint8_t *x_new;
        float *weights_new;
        nx = subsample_training_set (*this, nx, x, line_size, weights,
                                &x_new, &weights_new);
        del1.reset (x_new); x = x_new;
        del3.reset (weights_new); weights = weights_new;
    } else if (nx < k * min_points_per_centroid) {
        fprintf (stderr,
                 "WARNING clustering %" PRId64 " points to %zd centroids: "
                 "please provide at least %" PRId64 " training points\n",
                 nx, k, idx_t(k) * min_points_per_centroid);
    }

这里有一个corner case, 当训练数据集的数量等于聚类中心的数量时,那么直接每一个向量当成一个聚类中心,原样返回。

3.2 初始化并迭代更新聚类中心

外层有两个大循环:

for (int redo = 0; redo < nredo; redo++) { // nredo 前文介绍过,聚类的次数

        ...... // 每一次聚类之前,初始化聚类中心
        for (int i = 0; i < niter; i++) {
            ...... // 迭代更新聚类中心
        }
    }

1)初始化聚类中心

    rand_perm (perm.data(), nx, seed + 1 + redo * 15486557L);
     for (int i = n_input_centroids; i < k ; i++) {
           memcpy (&centroids[i * d], x + perm[i] * line_size, line_size);
     }
    index.add (k, centroids.data());

随机打散, 取k个点,拷贝至聚类中心数组,并添加到flat index。
2)迭代更新聚类中心

        index.search (nx, reinterpret_cast<const float *>(x), 1, dis.get(), assign.get());
       // accumulate error
       err = 0;
       for (int j = 0; j < nx; j++) {
              err += dis[j];
        }

通过flat index查询距离每个向量最近的聚类中心点,并将聚类中心id存在assign数组中,与聚类中心的距离存在dis数组中。比如序号为3的向量最近的聚类中心id为1,则assigin[3] = 1,dis[3]则是两者的距离。err是总偏差,保存所有的样本点到最近的聚类中心距离之和,用于最后寻找最优的聚类中心结果。

// 计算每一簇的中心,并更新 
 compute_centroids (  d, k, nx, k_frozen, x, codec, assign.get(), weights, hassign.data(), centroids.data() ); 

// 对于新生成的簇,有的簇可能没有向量,取一个向量较多的簇分割成两个小簇
int nsplit = split_clusters (  d, k, nx, k_frozen, hassign.data(), centroids.data() );

 由于这两个函数比较核心,这里删除部分代码后贴出源代码并增加相应的注释。

compute_centroids的作用是计算每个簇的所有向量的总和,以及向量个数,得到每一簇的均值,即为新的聚类中心。为了提升计算的速度,将簇按照线程数分段,每一个线程计算对应分段的簇。举个例子,现在有10个线程,100个簇,那么0号线程计算0-9号簇,1号线程计算10-19号簇,以此类推。

split_cluster的作用是找出数量为0的簇,并找出一个较大的簇,将其平均分成两份,并更新两个小簇对应的中心点。

void compute_centroids (size_t d, size_t k, size_t n,
                       size_t k_frozen,
                       const uint8_t * x, const Index *codec,
                       const int64_t * assign,
                       const float * weights,
                       float * hassign,
                       float * centroids)
{
    k -= k_frozen;
    centroids += k_frozen * d;

    memset (centroids, 0, sizeof(*centroids) * d * k); // 清零

    size_t line_size = codec ? codec->sa_code_size() : d * sizeof (float); //每一个向量的size

#pragma omp parallel  //并行计算,将centroids分段,每个线程只计算对应的分段
    {
        int nt = omp_get_num_threads(); //获取总的线程数
        int rank = omp_get_thread_num(); //获取当前线程id

        // this thread is taking care of centroids c0:c1
        size_t c0 = (k * rank) / nt; 
        size_t c1 = (k * (rank + 1)) / nt;
        std::vector<float> decode_buffer (d);

        for (size_t i = 0; i < n; i++) {
            int64_t ci = assign[i]; //获取离第i个向量最近的聚类中心id(ci)
            if (ci >= c0 && ci < c1)  { //只计算c0-c1的
                float * c = centroids + ci * d;
                const float * xi;
                xi = reinterpret_cast<const float*>(x + i * line_size);
                
                    hassign[ci] += 1.0; //ci簇的向量数加1
                    for (size_t j = 0; j < d; j++) {
                        c[j] += xi[j]; //获取ci簇每个维度上的总和
                    }
            }
        }

    }

#pragma omp parallel for //并行计算每个簇的均值,得到新的聚类中心
    for (idx_t ci = 0; ci < k; ci++) {
        if (hassign[ci] == 0) {
            continue;
        }
        float norm = 1 / hassign[ci];
        float * c = centroids + ci * d;
        for (size_t j = 0; j < d; j++) {
            c[j] *= norm;
        }
    }
}

int split_clusters (size_t d, size_t k, size_t n,
                    size_t k_frozen,
                    float * hassign,
                    float * centroids)
{
    /* Take care of void clusters */
    size_t nsplit = 0;
    RandomGenerator rng (1234);
    for (size_t ci = 0; ci < k; ci++) {
        if (hassign[ci] == 0) { /* 数量为0的簇,需要找一个大粗分割 */
            size_t cj;
            for (cj = 0; 1; cj = (cj + 1) % k) {
                /* probability to pick this cluster for split */
                float p = (hassign[cj] - 1.0) / (float) (n - k);
                float r = rng.rand_float ();
                if (r < p) {
                    break; /* 找到一个分割的大簇 */
                }
            }
            //将大簇中心copy给小簇
            memcpy (centroids+ci*d, centroids+cj*d, sizeof(*centroids) * d);

            /* 通过对两个相同的中心添加反向扰动,从而分成两个中心 */
            for (size_t j = 0; j < d; j++) {
                if (j % 2 == 0) {
                    centroids[ci * d + j] *= 1 + EPS;
                    centroids[cj * d + j] *= 1 - EPS;
                } else {
                    centroids[ci * d + j] *= 1 - EPS;
                    centroids[cj * d + j] *= 1 + EPS;
                }
            }

            /* 更新对应簇的数量 */
            hassign[ci] = hassign[cj] / 2;
            hassign[cj] -= hassign[ci];
            nsplit++;
        }
    }
    return nsplit;
}

通过这两个函数,可以得到更新后的中心,再将flat index 清零,并将新的中心点添加到index,作为下一次迭代的搜索索引。

这样知道迭代次数结束,选出距离偏差err最小的一次训练结果作为聚类的最后结果。

if (nredo > 1) {
        centroids = best_centroids;
        iteration_stats = best_obj;
        index.reset();
        index.add(k, best_centroids.data());
    }

至此,Faiss中聚类的源码剖析结束~ 

总结Faiss的聚类,大概以下几点:

如果训练数据集过大,则采样部分来做训练。

 faiss中聚类的停止条件是达到了指定的迭代次数。(距离收敛应该也可以作为停止条件,但是faiss中采用的是迭代次数)

在更新聚类中心的时候,采用并行计算各簇均值。

对于没有数据的空簇,找一个较大的簇平分成两份。

  • 6
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

jevenabc

请我喝杯咖啡吧~

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值