K-最近邻搜索

问题描述

在这一节中, 我们会形式化地定义向量最近邻搜索 (Nearest Neighbor Search) 这一问题:给定一个向量的集合 X ∈ R^{n×d}, 其中包含了 n 个 d 维的向量。对于一个查询向量 q, 向量最近邻搜索的目的是从 X 中找到一个向量 x^*, 满足

x^* = \arg\min_{x \in X} \delta(x, q)

即寻找 q 最近邻的向量。其中 δ(·, ·) 表示向量之间的距离函数, 它的数学定义如下 (欧几里得距离):

\delta(x, y) = \sqrt{\sum_{i=1}^{d} (x_i - y_i)^2}

其中, x_i 和 y_i 分别表示 x 和 y 的第 i 维分量。

如上定义可以非常简单地扩展到 K-最近邻搜索 (K-Nearest Neighbor Search), 即搜索 K 个最近邻的向量。但是当 n 和 d 的数目显著增加的时候, 精确最近邻搜索往往不再适用 (搜索成本太高), 我们往往会在牺牲准确率的情况下, 提高搜索的速度, 这就是近似最近邻搜索 (Approximate Nearest Neighbor Search, ANN)。近似最近邻搜索的衡量指标 (召回率) 定义为

 \displaystyle Recall@k=\frac{|R\cap \tilde{R}|}{k}

其中 R 是 q 真正的 k 个最近邻向量的集合, \tilde{R} 是通过算法得到的近似最近邻向量的集合。

在本次作业中, 你需要设计一个算法来提高 ANN 搜索的效率, 我们要求你设计的算法需要保证 Recall@k 的召回率大于 90%.

输入数据的格式如下:

[输入] 向量集合 X 的大小 n, 向量的维度 d, 以及搜索最近邻的数目 k, 紧接着是 n 个 d 维的浮点数向量数据。然后是查询向量的个数 nq, 以及 nq 个 d 维的浮点数向量数据.

[输出] 对于每一个查询向量, 输出 k 个与之最近邻的 ID(ID 为该向量在 X 中所处的位置下标)

[示例]

[input]
5 3 2
2.3 1.8 4.5
3.7 0.9 2.1
1.5 4.2 3.6
2.9 3.4 0.7
4.1 2.6 1.3
2
2.3 1.8 4.3
3.7 0.9 2.5
[output]
0 2
1 4

注: n 的范围在 1w-100w 之间, d 的范围在 64-512 之间, nq 的范围在 100-1000 之间, k在 1-50 之间。

  • 实验记录和结果

  1. 纯暴力
    1. 最直白朴素的想法,首先存储所有向量,再对每个点求出距离并进行存储。最后整体进行快排,输出前k个点。
    2. 代码
      #include <iostream>  
      #include <vector>  
      #include <cmath>  
      #include <limits>  
      #include <algorithm>  
        
      using namespace std;  
        
      // Returns the squared Euclidean distance between vec1 and vec2.
      // No square root is taken: the caller in this program only sorts by the
      // value, and squaring preserves the ordering of non-negative distances.
      // If the vectors differ in length, only the common prefix is compared so
      // that neither vector is read out of bounds.
      double euclideanDistance(const std::vector<double>& vec1, const std::vector<double>& vec2) {
          const std::size_t dims = std::min(vec1.size(), vec2.size());
          double distance = 0.0;
          for (std::size_t i = 0; i < dims; ++i) {
              // Plain multiplication instead of the general-purpose pow().
              const double diff = vec1[i] - vec2[i];
              distance += diff * diff;
          }
          return distance;
      }
        
      // 暴力算法实现最近邻搜索  
      vector<int> bruteForceNearestNeighborSearch(const vector<vector<double>>& dataset, const vector<double>& query, int k) {  
          vector<pair<double, int>> distances; 
      // 存储距离和对应的向量ID  
          for (size_t i = 0; i < dataset.size(); ++i) {  
              double dist = euclideanDistance(dataset[i], query);  
              distances.push_back(make_pair(dist, i));  
          }  
            
          // 按照距离排序,并取出最近的k个  
          sort(distances.begin(), distances.end());  
          vector<int> result;  
          for (int i = 0; i < k; ++i) {  
              result.push_back(distances[i].second);  
          }  
          return result;  
      }  
        
      int main() {  
          int n, d, k;  
          cin >> n >> d >> k; 
          vector<vector<double>> dataset(n, vector<double>(d));  
          for (int i = 0; i < n; ++i) {  
              for (int j = 0; j < d; ++j) {  
                  cin >> dataset[i][j]; // 输入向量集合中的每个向量  
              }  
          }  
            
          int m;  
          cin >> m;
          vector<vector<double>> queries(m, vector<double>(d));  
          for (int i = 0; i < m; ++i) {  
              for (int j = 0; j < d; ++j) {  
                  cin >> queries[i][j]; // 输入每个查询向量  
              }  
          }  
            
          // 对每个查询向量执行最近邻搜索  
          for (const auto& query : queries) {  
              vector<int> nearestNeighbors = bruteForceNearestNeighborSearch(dataset, query, k);  
              // 输出最近邻的ID  
              for (int id : nearestNeighbors) {  
                  cout << id << " ";  
              }  
              cout << endl;  
          }  
            
          return 0;  
      }

  2. 暴力,优先队列
    1. 方案一中考虑到整体快排时间过长,尝试进行分块排序。然后想到使用priority-queue进行存储,用大根堆维护堆顶的k个点,最后输出p-q中存储的点即可。
    2. 优化:该方法若遍历全部点,理论召回率为100%。此处考虑到数据集较大,向量分布均匀,因此在实际操作中,只遍历了前一部分点进行距离计算(代码中 n ≥ 4000 时取前 81%,否则取前 90%),以牺牲少量召回率换取速度。
    3. 代码
      #pragma GCC optimize("Ofast")
      #pragma GCC optimize("inline")
      #include <iostream>
      #include <math.h>
      using namespace std;
      // Reads the next decimal integer from stdin.
      // Characters before the first digit are skipped; a '-' directly in front of
      // the first digit makes the result negative. The character following the
      // last digit is consumed as well. Returns 0 if end of input is reached
      // before any digit is found.
      inline int read()
      {
          // getchar() returns int so that EOF can be told apart from every real
          // character; keeping it as int also keeps isdigit() well defined.
          int ch = getchar();
          bool negative = false;
          while (ch != EOF && !isdigit(ch))
          {
              // Only a sign immediately before the number counts.
              negative = (ch == '-');
              ch = getchar();
          }

          int x = 0;
          while (ch != EOF && isdigit(ch))
          {
              x = x * 10 + (ch - '0');
              ch = getchar();
          }
          return negative ? -x : x;
      }
      // Reads the next decimal floating-point number from stdin.
      // Accepted form: [-]digits[.digits][(e|E)[+|-]digits]. Characters before the
      // first digit are skipped; a '-' directly in front of the first digit makes
      // the result negative. The character following the number is consumed.
      // Returns 0 if end of input is reached before any digit is found.
      inline float readdou()
      {
          // int, not char: EOF must stay distinguishable and isdigit() needs a
          // value in unsigned-char range.
          int ch = getchar();
          bool negative = false;
          while (ch != EOF && !isdigit(ch))
          {
              negative = (ch == '-');
              ch = getchar();
          }

          // Integer part.
          float x = 0;
          while (ch != EOF && isdigit(ch))
          {
              x = x * 10 + (ch - '0');
              ch = getchar();
          }

          // Optional fraction.
          if (ch == '.')
          {
              float scale = 0.1;
              ch = getchar();
              while (ch != EOF && isdigit(ch))
              {
                  x = x + (ch - '0') * scale;
                  scale /= 10;
                  ch = getchar();
              }
          }

          // Optional exponent, accepted after the integer part or the fraction.
          if (ch == 'e' || ch == 'E')
          {
              ch = getchar();
              bool expNegative = false;
              if (ch == '+' || ch == '-')
              {
                  expNegative = (ch == '-');
                  ch = getchar();
              }
              int exponent = 0;
              while (ch != EOF && isdigit(ch))
              {
                  // Stop accumulating once the exponent is far beyond float
                  // range, so absurd input cannot overflow the int.
                  if (exponent < 1000)
                      exponent = exponent * 10 + (ch - '0');
                  ch = getchar();
              }
              x = static_cast<float>(x * pow(10.0, expNegative ? -exponent : exponent));
          }
          return negative ? -x : x;
      }
      // Writes x to stdout in decimal, with a leading '-' for negative values.
      inline void write(int x)
      {
          // The magnitude is handled as unsigned so that the most negative int,
          // whose negation does not fit in int, is printed correctly.
          unsigned int magnitude = static_cast<unsigned int>(x);
          if (x < 0)
          {
              putchar('-');
              magnitude = 0u - magnitude;
          }

          // Digits are produced least-significant first, then emitted in reverse.
          char num[30];
          int cnt = 0;
          do
          {
              num[cnt++] = static_cast<char>('0' + magnitude % 10);
              magnitude /= 10;
          } while (magnitude > 0);

          while (cnt > 0)
          {
              putchar(num[--cnt]);
          }
      }
      // Minimal growable array that owns a heap buffer of T.
      // T must be default-constructible and assignable. Copying makes a deep
      // copy; moving transfers the buffer and leaves the source empty.
      template <typename T>
      class MyVector
      {
      private:
          T *data;         // owned buffer, nullptr while nothing has been allocated
          size_t capacity; // number of elements the buffer can hold
          size_t size;     // number of elements in use

          // Replaces the buffer with one of new_capacity elements (expected to be
          // at least `size`) and moves the existing elements into it.
          void reallocate(size_t new_capacity)
          {
              T *new_data = new T[new_capacity];
              if (data != nullptr)
              {
                  // Moving instead of copying keeps growth cheap when T itself
                  // owns memory (for example a nested MyVector).
                  std::move(data, data + size, new_data);
                  delete[] data;
              }
              data = new_data;
              capacity = new_capacity;
          }

          // Exchanges the complete state of two vectors without allocating.
          void swapWith(MyVector &other) noexcept
          {
              T *d = data;
              data = other.data;
              other.data = d;
              size_t c = capacity;
              capacity = other.capacity;
              other.capacity = c;
              size_t s = size;
              size = other.size;
              other.size = s;
          }

      public:
          typedef T *iterator;
          typedef const T *const_iterator;

          // Creates an empty vector; nothing is allocated.
          MyVector() : data(nullptr), capacity(0), size(0) {}

          // Creates a vector of `count` copies of `value`.
          MyVector(size_t count, const T &value = T()) : data(nullptr), capacity(0), size(0)
          {
              resize(count, value);
          }

          // Deep copy. An empty source yields an empty vector with no buffer.
          MyVector(const MyVector &other) : data(nullptr), capacity(0), size(0)
          {
              if (other.size > 0)
              {
                  data = new T[other.capacity];
                  std::copy(other.data, other.data + other.size, data);
                  size = other.size;
                  capacity = other.capacity;
              }
          }

          // Takes over other's buffer; other is left empty.
          MyVector(MyVector &&other) noexcept
              : data(other.data), capacity(other.capacity), size(other.size)
          {
              other.data = nullptr;
              other.capacity = 0;
              other.size = 0;
          }

          // Copy assignment: the copy is built first and swapped in afterwards,
          // so *this is left untouched if allocating the copy throws.
          MyVector &operator=(const MyVector &other)
          {
              if (this != &other)
              {
                  MyVector duplicate(other);
                  swapWith(duplicate);
              }
              return *this;
          }

          // Move assignment: releases the current buffer and takes over other's.
          MyVector &operator=(MyVector &&other) noexcept
          {
              if (this != &other)
              {
                  delete[] data;
                  data = other.data;
                  capacity = other.capacity;
                  size = other.size;
                  other.data = nullptr;
                  other.capacity = 0;
                  other.size = 0;
              }
              return *this;
          }

          ~MyVector()
          {
              delete[] data;
          }

          iterator begin()
          {
              return data;
          }
          iterator end()
          {
              return data + size;
          }
          const_iterator begin() const
          {
              return data;
          }
          const_iterator end() const
          {
              return data + size;
          }

          // Appends a copy of value, doubling the capacity when the buffer is full.
          void push_back(const T &value)
          {
              if (size == capacity)
              {
                  reallocate(capacity == 0 ? 1 : capacity * 2);
              }
              data[size++] = value;
          }

          // Sets the element count to new_size. New elements are copies of value;
          // shrinking keeps the buffer and only lowers the count.
          void resize(size_t new_size, const T &value = T())
          {
              if (new_size > capacity)
              {
                  reallocate(new_size);
              }
              for (size_t i = size; i < new_size; ++i)
              {
                  data[i] = value;
              }
              size = new_size;
          }

          // Unchecked element access; index must be below the element count.
          T &operator[](size_t index)
          {
              return data[index];
          }
          const T &operator[](size_t index) const
          {
              return data[index];
          }
      };
      // Heap entry: (distance, point ID). Entries are ordered by distance only.
      typedef std::pair<float, int> Queue_entry;
      // Capacity of the heap storage; the problem statement limits k to 50.
      const int maxqueue = 51;

      // Fixed-capacity binary max-heap of Queue_entry, keyed on the distance.
      // Intended use for k-nearest-neighbour selection: append() the first k
      // candidates, then for every later candidate that beats top(), insert() it
      // in place of the current maximum. pop() removes entries farthest first.
      class CustomPriorityQueue
      {
      public:
          CustomPriorityQueue() : count(0) {}

          // Adds item and restores the heap order by sifting it up, so top() is
          // always the largest distance. Does nothing once maxqueue entries are
          // stored, which keeps the write inside the fixed array.
          void append(const Queue_entry &item)
          {
              if (count >= maxqueue)
                  return;
              int child = count++;
              while (child > 0)
              {
                  const int parent = (child - 1) / 2;
                  if (entry[parent].first >= item.first)
                      break;
                  entry[child] = entry[parent];
                  child = parent;
              }
              entry[child] = item;
          }

          // Replaces the current maximum with x and sifts x down to its place.
          // On an empty queue this behaves like append().
          void insert(const Queue_entry &x)
          {
              if (count == 0)
              {
                  append(x);
                  return;
              }
              insert_heap(x, 0, count - 1);
          }

          // Sift-down: places cur into the sub-heap rooted at index low, treating
          // high as the last valid index. Children larger than cur move up.
          void insert_heap(const Queue_entry cur, int low, int high)
          {
              int large = 2 * low + 1;
              while (large <= high)
              {
                  // Pick the larger of the two children.
                  if (large < high && entry[large].first < entry[large + 1].first)
                      large++;
                  if (cur.first >= entry[large].first)
                      break;
                  entry[low] = entry[large];
                  low = large;
                  large = 2 * low + 1;
              }
              entry[low] = cur;
          }

          // Re-establishes the heap order over all stored entries.
          void build_heap()
          {
              for (int low = count / 2 - 1; low >= 0; low--)
              {
                  Queue_entry cur = entry[low];
                  insert_heap(cur, low, count - 1);
              }
          }

          // Returns the entry with the largest distance. The queue must not be empty.
          Queue_entry top() const
          {
              return entry[0];
          }

          // Removes the entry with the largest distance; no effect when empty.
          void pop()
          {
              if (count == 0)
                  return;
              // The last entry (index count - 1 before the decrement) takes the
              // root's place and is sifted down over the remaining entries.
              --count;
              if (count > 0)
                  insert_heap(entry[count], 0, count - 1);
          }

          bool empty() const
          {
              return count == 0;
          }

      private:
          int count;                   // number of stored entries
          Queue_entry entry[maxqueue]; // heap storage, root at index 0
      };
      // Problem sizes filled in by main: number of data vectors, vector dimension,
      // neighbours requested per query, and number of queries.
      int n, d, k, qn;

      // Returns the squared Euclidean distance between a and b over their first d
      // components (d is the global dimension). No square root is taken.
      float Distance(MyVector<float>& a,MyVector<float>& b) {
          float total = 0.0;
          int idx = 0;
          while (idx < d) {
              const float delta = a[idx] - b[idx];
              total += delta * delta;
              ++idx;
          }
          return total;
      }
      int main()
      {
          n = read(), d = read(), k = read();
          MyVector<MyVector<float>> tmp(n, MyVector<float>(d));
          for (int i = 0; i < n; i++)
          {
              for (int j = 0; j < d; j++)
                  tmp[i][j] = readdou();
          }
          qn = read();
          MyVector<float> target(d);
          for (int i = 0; i < qn; i++)
          {
              
              for (int j = 0; j < d; j++)
                  target[j] = readdou();
              CustomPriorityQueue ans;
              int j = 0;
              int nn = n * 0.9;
              if (n >= 4000)
                  nn = n * 0.81;
              while (j < nn)
              {
                  float tmpDis = Distance(tmp[j],target);
      
                  if (j < k)
                  {
                      ans.append({tmpDis, j});
                  }
                  else
                  {
                      if (tmpDis < ans.top().first)
                      {
                          ans.insert({tmpDis, j});
                      }
                  }
                  ++j;
              }
              while (!ans.empty())
              {
                  write(ans.top().second);
                  putchar(' ');
                  ans.pop();
              }
              puts("");
          }
          return 0;
      }
  3. 基于树的方法:KD-tree
    1. KD-tree是一种分割k维空间的数据结构,特别适合处理高维空间中的近邻搜索问题,它通过递归地将空间分为多个超矩形区域,每个区域内部的点都比较接近。
    2. 算法步骤设计
      • 构建KD-tree
        1. 输入: 向量集合 X,维度 d。
        2. 过程:
          1. 选择当前维度 d 中的中位数对应的坐标值作为切分点。
          2. 将数据集依据此切分点分为两部分,小于等于切分点的点构成左子树,大于切分点的点构成右子树。
          3. 对左右子树递归地重复上述过程,直到子树中只有一个点或达到预设的叶子节点大小限制。
      • 进行最近邻搜索
        1. 输入: 查询向量 q,搜索最近邻数目 k。
        2. 过程:
          1. 初始化一个当前最佳候选点列表 best_candidates,并设置一个全局最佳距离 best_distance 为无穷大。
          2. 使用深度优先搜索遍历kd树:
            1. 如果 当前节点为空,则 返回。
            2. 如果 当前节点是一个叶节点,则 将叶节点中的所有点与 q 计算距离,更新 best_candidates 和 best_distance。
          3. 计算 当前节点的分割超平面与 q 的投影距离。
          4. 如果 投影距离小于当前最佳距离 best_distance,则
            1. 首先搜索离 q 更近的子树(即与分割超平面同一侧的子树)。
            2. 再搜索另一侧子树,但仅当搜索半径(即投影距离加上 best_distance)能够覆盖到另一侧的节点。
          5. 否则,根据投影距离决定搜索方向(离 q 更远的一侧可能包含最近邻)。
      • 最终,返回 best_candidates 中的前 k 个元素作为结果。
    3. 空间复杂度与时间复杂度分析
      • 空间复杂度: KD-tree的空间复杂度主要取决于树的深度和每个节点存储的点的数量。理想情况下,KD-tree的深度为 O(log(n)),每个节点存储的数据量取决于数据分布,但平均来看,总的空间复杂度为 O(n)。
      • 时间复杂度: 最近邻搜索的时间复杂度在最坏情况下是 O(n),因为可能需要遍历整个数据集。但在平均情况下,尤其是数据均匀分布时,搜索时间可以降低到 O(log(n) + k*log(n)),其中 log(n) 来自KD-tree的遍历,k*log(n) 来自维护候选集的更新。
    4. 代码
      #include <iostream>
      #include <vector>
      #include <queue>
      #include <algorithm>
      #include <chrono>
      #include <cstdio>
      #pragma GCC optimize("Ofast")
      #pragma GCC optimize("inline")
      
      using namespace std;
      
      // Capacity, in components, of the fixed-size coordinate buffers that are
      // declared with this constant.
      // NOTE(review): the problem statement allows d up to 512, so dimensions
      // above 128 do not fit such buffers.
      const int MAX_DIMENSION = 128;
      
      // Fast input functions

      // Reads the next decimal integer from stdin.
      // Characters before the first digit are skipped; a '-' directly in front of
      // the first digit makes the result negative. The character following the
      // last digit is consumed as well. Returns 0 if end of input is reached
      // before any digit is found.
      int fastReadInt() {
          // getchar() returns int so that EOF can be told apart from every real
          // character; keeping it as int also keeps isdigit() well defined.
          int ch = getchar();
          bool negative = false;
          while (ch != EOF && !isdigit(ch)) {
              // Only a sign immediately before the number counts.
              negative = (ch == '-');
              ch = getchar();
          }

          int value = 0;
          while (ch != EOF && isdigit(ch)) {
              value = value * 10 + (ch - '0');
              ch = getchar();
          }
          return negative ? -value : value;
      }
      
      // Reads the next decimal floating-point number from stdin.
      // Accepted form: [-]digits[.digits][(e|E)[+|-]digits]. Characters before the
      // first digit are skipped; a '-' directly in front of the first digit makes
      // the result negative. The character following the number is consumed.
      // Returns 0 if end of input is reached before any digit is found.
      float fastReadFloat() {
          int ch = getchar();
          bool negative = false;
          while (ch != EOF && !isdigit(ch)) {
              negative = (ch == '-');
              ch = getchar();
          }

          // Integer part.
          float value = 0;
          while (ch != EOF && isdigit(ch)) {
              value = value * 10 + (ch - '0');
              ch = getchar();
          }

          // Optional fraction.
          if (ch == '.') {
              float fraction = 0.1;
              ch = getchar();
              while (ch != EOF && isdigit(ch)) {
                  value = value + (ch - '0') * fraction;
                  fraction /= 10;
                  ch = getchar();
              }
          }

          // Optional exponent, so that values written in scientific notation are
          // read as one number instead of being split at the 'e'.
          if (ch == 'e' || ch == 'E') {
              ch = getchar();
              bool expNegative = false;
              if (ch == '+' || ch == '-') {
                  expNegative = (ch == '-');
                  ch = getchar();
              }
              int exponent = 0;
              while (ch != EOF && isdigit(ch)) {
                  // Stop accumulating once the exponent is far beyond float
                  // range; this also bounds the scaling loop below.
                  if (exponent < 100) exponent = exponent * 10 + (ch - '0');
                  ch = getchar();
              }
              for (int i = 0; i < exponent; ++i) {
                  value = expNegative ? value / 10 : value * 10;
              }
          }
          return negative ? -value : value;
      }
      
      // Writes number to stdout in decimal, with a leading '-' for negative values.
      void fastWriteInt(int number) {
          // The magnitude is handled as unsigned so that the most negative int,
          // whose negation does not fit in int, is printed correctly.
          unsigned int magnitude = static_cast<unsigned int>(number);
          if (number < 0) {
              putchar('-');
              magnitude = 0u - magnitude;
          }

          // Digits are produced least-significant first, then emitted in reverse.
          char buffer[30];
          int index = 0;
          do {
              buffer[index++] = static_cast<char>('0' + magnitude % 10);
              magnitude /= 10;
          } while (magnitude > 0);

          while (index > 0) {
              putchar(buffer[--index]);
          }
      }
      
      // A point in high-dimensional space together with its ID.
      // The coordinates are stored in a vector sized to the point's own
      // dimension, so any dimension fits and no unused slots are carried around.
      struct DataPoint {
          std::vector<double> coordinates; // coordinates[i] is the i-th component
          int identifier;                  // ID of the point; -1 until assigned

          // Creates a point with no coordinates and the placeholder ID -1.
          DataPoint() : identifier(-1) {}

          // Copies the first dim values of coords. A non-positive dim yields a
          // point without coordinates.
          DataPoint(int id, const double* coords, int dim)
              : coordinates(coords, coords + (dim > 0 ? dim : 0)), identifier(id) {}
      };
      
      // One KD-tree node: a stored point plus links to its two subtrees.
      // Both links start out null; the node holds its own copy of the point.
      struct KDTreeNode {
          DataPoint dataPoint;
          KDTreeNode* leftChild = nullptr;
          KDTreeNode* rightChild = nullptr;

          KDTreeNode(const DataPoint& point) : dataPoint(point) {}
      };
      
      // KD-tree class
      class KDTree {
      public:
          int dimension;
          KDTreeNode* root;
      
          KDTree(int dim) : dimension(dim), root(nullptr) {}
      
          // Recursive function to build the tree
          KDTreeNode* constructTree(vector<DataPoint>& points, int currentDepth = 0) {
              if (points.empty()) return nullptr;
      
              int splitAxis = currentDepth % dimension;
              int medianIndex = points.size() / 2;
      
              nth_element(points.begin(), points.begin() + medianIndex, points.end(),
                          [splitAxis](const DataPoint& a, const DataPoint& b) {
                              return a.coordinates[splitAxis] < b.coordinates[splitAxis];
                          });
      
              KDTreeNode* newNode = new KDTreeNode(points[medianIndex]);
              vector<DataPoint> leftPoints(points.begin(), points.begin() + medianIndex);
              vector<DataPoint> rightPoints(points.begin() + medianIndex + 1, points.end());
      
              newNode->leftChild = constructTree(leftPoints, currentDepth + 1);
              newNode->rightChild = constructTree(rightPoints, currentDepth + 1);
      
              return newNode;
          }
      
          // Insertion function for KD-tree
          void addPoint(const DataPoint& point) {
              root = insertIntoTree(root, point, 0);
          }
      
          KDTreeNode* insertIntoTree(KDTreeNode* node, const DataPoint& point, int currentDepth) {
              if (!node) return new KDTreeNode(point);
      
              int splitAxis = currentDepth % dimension;
              if (point.coordinates[splitAxis] < node->dataPoint.coordinates[splitAxis]) {
                  node->leftChild = insertIntoTree(node->leftChild, point, currentDepth + 1);
              } else {
                  node->rightChild = insertIntoTree(node->rightChild, point, currentDepth + 1);
              }
              return node;
          }
      
          // Function to calculate squared distance between two points
          double squaredDistance(const DataPoint& a, const DataPoint& b) {
              double distance = 0.0;
              for (int i = 0; i < dimension; ++i) {
                  double difference = a.coordinates[i] - b.coordinates[i];
                  distance += difference * difference;
              }
              return distance;
          }
      
          // K-nearest neighbors search algorithm
          void findKNearestNeighbors(const DataPoint& target, int k, priority_queue<pair<double, int>>& resultQueue, KDTreeNode* node, int currentDepth = 0) {
              if (!node) return;
      
              double dist = squaredDistance(target, node->dataPoint);
              if (resultQueue.size() < k) {
                  resultQueue.push({dist, node->dataPoint.identifier});
              } else if (dist < resultQueue.top().first) {
                  resultQueue.pop();
                  resultQueue.push({dist, node->dataPoint.identifier});
              }
      
              int splitAxis = currentDepth % dimension;
              KDTreeNode* nextNode = target.coordinates[splitAxis] < node->dataPoint.coordinates[splitAxis] ? node->leftChild : node->rightChild;
              KDTreeNode* oppositeNode = target.coordinates[splitAxis] < node->dataPoint.coordinates[splitAxis] ? node->rightChild : node->leftChild;
      
              findKNearestNeighbors(target, k, resultQueue, nextNode, currentDepth + 1);
              if (resultQueue.size() < k || fabs(target.coordinates[splitAxis] - node->dataPoint.coordinates[splitAxis]) < sqrt(resultQueue.top().first)) {
                  findKNearestNeighbors(target, k, resultQueue, oppositeNode, currentDepth + 1);
              }
          }
      };
      
      int main() {
          int numPoints, dimension, numNearestNeighbors;
          numPoints = fastReadInt();
          dimension = fastReadInt();
          numNearestNeighbors = fastReadInt();
      
          vector<DataPoint> dataPoints;
          for (int i = 0; i < numPoints; ++i) {
              double coordinates[MAX_DIMENSION];
              for (int j = 0; j < dimension; ++j) {
                  coordinates[j] = fastReadFloat();
              }
              dataPoints.emplace_back(i, coordinates, dimension);
          }
      
          int numQueries;
          numQueries = fastReadInt();
      
          vector<DataPoint> queryPoints;
          for (int i = 0; i < numQueries; ++i) {
              double coordinates[MAX_DIMENSION];
              for (int j = 0; j < dimension; ++j) {
                  coordinates[j] = fastReadFloat();
              }
              queryPoints.emplace_back(i, coordinates, dimension);
          }
      
          KDTree kdTree(dimension);
          kdTree.root = kdTree.constructTree(dataPoints);
      
          for (const auto& query : queryPoints) {
              priority_queue<pair<double, int>> nearestNeighbors;
              kdTree.findKNearestNeighbors(query, numNearestNeighbors, nearestNeighbors, kdTree.root);
              vector<int> results;
              while (!nearestNeighbors.empty()) {
                  results.push_back(nearestNeighbors.top().second);
                  nearestNeighbors.pop();
              }
              reverse(results.begin(), results.end());
              for (int index : results) {
                  fastWriteInt(index);
                  putchar(' ');
              }
              putchar('\n');
          }
          return 0;
      }

  4. 基于图的方法:NSW
    1. NSW(Navigable Small World)算法是一种用于近似最近邻搜索的高效算法,特别适用于高维数据。它基于小世界网络理论,旨在构建一个图结构,使得在图中进行的贪婪路由能够快速找到与目标查询点最近的邻居点。NSW算法的主要思想是构建一个可导航的小世界图(Navigable Small World Graph),其中每个顶点代表一个数据点,边表示数据点之间的相似性。
    2. 算法步骤设计
      • 插入新点:当一个新点被插入时,算法使用近似k近邻搜索(approximate kNN search)从当前图中找到与新点最近的f个点(f是预先设定的参数),然后将新点与这f个点相连。
      • 构建近似Delaunay图:随着更多点的插入,图中最初形成的短距离边逐渐变为长距离边,形成了一种类似于小世界的网络结构,这种结构有利于减少搜索路径长度,加速搜索过程。
      • 贪心搜索:在搜索最近邻时,算法从一个起点开始,沿着图中与目标点距离最小的边进行贪婪搜索,直到达到一个被认为是最近邻的节点为止。
    3. 代码
      #include <iostream>
      #include <vector>
      #include <queue>
      #include <algorithm>
      #include <cmath>
      #include <ctime>
      
      using namespace std;
      
      // Constants: fixed capacities of the global arrays below.
      // NOTE(review): the problem statement allows n up to 1,000,000 and d up to
      // 512; inputs beyond these capacities do not fit the arrays.
      const int MAX_POINTS = 25010; // number of rows available for points
      const int MAX_DIMENSIONS = 130; // number of components available per point
      const int MAX_QUERIES = 110; // maximum number of queries; not referenced elsewhere in the visible code

      // Global state
      int totalPoints, dimensions, numNeighbors, numQueries; // dataset size, vector dimension, neighbours per query, query count
      float pointCoordinates[MAX_POINTS][MAX_DIMENSIONS]; // coordinates of each point, one row per point
      int maxConnections = 10; // number of neighbours a newly inserted point is linked to during graph construction

      // Adjacency lists of the graph: connectionGraph[p] holds the neighbours of point p.
      vector<int> connectionGraph[MAX_POINTS];
      
      // Adds an undirected edge between two points by recording each endpoint in
      // the other's adjacency list.
      void establishConnection(int firstPoint, int secondPoint)
      {
          std::vector<int> &neighboursOfFirst = connectionGraph[firstPoint];
          neighboursOfFirst.push_back(secondPoint);

          std::vector<int> &neighboursOfSecond = connectionGraph[secondPoint];
          neighboursOfSecond.push_back(firstPoint);
      }
      
      // Returns the squared Euclidean distance between two points, taken over the
      // first `dimensions` components of each. No square root is applied.
      float calculateDistance(float *firstPoint, float *secondPoint)
      {
          float total = 0;
          const float *lhs = firstPoint;
          const float *rhs = secondPoint;
          for (int remaining = dimensions; remaining > 0; --remaining, ++lhs, ++rhs)
          {
              const float delta = *lhs - *rhs;
              total += delta * delta;
          }
          return total;
      }
      
      // Greedy descent on the graph: starting at node 1, repeatedly scans the
      // neighbours of the current node and moves to the closest one found, until
      // a scan brings no improvement. Returns the node reached.
      int locateNearestNode(int targetPoint)
      {
          float *target = pointCoordinates[targetPoint];
          int bestNode = 1;
          float bestDistance = calculateDistance(target, pointCoordinates[bestNode]);

          bool improved = true;
          while (improved)
          {
              improved = false;
              // The list scanned in this round belongs to the node that was best
              // when the round started, even if bestNode changes during the scan.
              const std::vector<int> &neighbours = connectionGraph[bestNode];
              for (std::size_t pos = 0; pos < neighbours.size(); ++pos)
              {
                  const int candidate = neighbours[pos];
                  const float candidateDistance = calculateDistance(target, pointCoordinates[candidate]);
                  if (candidateDistance < bestDistance)
                  {
                      bestDistance = candidateDistance;
                      bestNode = candidate;
                      improved = true;
                  }
              }
          }

          return bestNode;
      }
      
      // Per-node visited flags used while a search is running.
      bool isVisited[MAX_POINTS] = {false};

      // Result of the most recent search: (distance, node) pairs.
      vector<pair<float, int>> currentSearchResults;
      
      // Ordering for search results: smaller distance first; for equal distances
      // the entry with the larger node index comes first.
      bool compareResults(std::pair<float, int> a, std::pair<float, int> b)
      {
          if (a.first != b.first)
              return a.first < b.first;
          return a.second > b.second;
      }
      
      // Best-first search for the k nodes nearest to targetPoint.
      // The search starts at the node returned by locateNearestNode, always
      // expands the closest unexpanded node, and stops once k * max(50 / k, 2)
      // nodes have been expanded or no unexpanded node is left. The expanded
      // nodes are then sorted with compareResults and cut down to k entries in
      // currentSearchResults.
      void findKNearest(int targetPoint, int k)
      {
          typedef std::pair<float, int> Candidate;

          float *target = pointCoordinates[targetPoint];
          const int entryNode = locateNearestNode(targetPoint);

          // Min-heap on distance: the closest discovered node is expanded next.
          std::priority_queue<Candidate, std::vector<Candidate>, std::greater<Candidate>> frontier;
          frontier.push(Candidate(calculateDistance(pointCoordinates[entryNode], target), entryNode));

          // Every node marked as visited is remembered so that the flags can be
          // reset afterwards without sweeping the whole array.
          std::vector<int> touched(1, entryNode);
          isVisited[entryNode] = true;

          currentSearchResults.clear();

          const std::size_t budget = static_cast<std::size_t>(k * std::max(50 / k, 2));

          while (!frontier.empty() && currentSearchResults.size() < budget)
          {
              const Candidate nearest = frontier.top();
              frontier.pop();
              currentSearchResults.push_back(nearest);

              const std::vector<int> &neighbours = connectionGraph[nearest.second];
              for (std::size_t pos = 0; pos < neighbours.size(); ++pos)
              {
                  const int next = neighbours[pos];
                  if (!isVisited[next])
                  {
                      isVisited[next] = true;
                      touched.push_back(next);
                      frontier.push(Candidate(calculateDistance(pointCoordinates[next], target), next));
                  }
              }
          }

          // Reset the visited flags for the next search.
          for (std::size_t pos = 0; pos < touched.size(); ++pos)
              isVisited[touched[pos]] = false;

          // Order the expanded nodes and keep only the first k.
          std::sort(currentSearchResults.begin(), currentSearchResults.end(), compareResults);
          if (currentSearchResults.size() > static_cast<std::size_t>(k))
              currentSearchResults.resize(k);
      }
      
      // Builds the graph by inserting points 1..totalPoints in index order.
      // The first maxConnections points are linked to every earlier point; each
      // later point is linked to the results of a findKNearest search for it on
      // the graph built so far.
      void constructNeighborGraph()
      {
          int pointIndex = 1;

          // Phase 1: the initial points form a fully connected seed graph.
          for (; pointIndex <= totalPoints && pointIndex <= maxConnections; ++pointIndex)
          {
              for (int earlier = 1; earlier < pointIndex; ++earlier)
                  establishConnection(pointIndex, earlier);
          }

          // Phase 2: every remaining point is attached to its search results.
          for (; pointIndex <= totalPoints; ++pointIndex)
          {
              findKNearest(pointIndex, maxConnections);
              for (const auto &hit : currentSearchResults)
                  establishConnection(pointIndex, hit.second);
          }
      }
      
      // 主函数
      int main()
      {
          srand(static_cast<unsigned int>(time(nullptr))); // 随机种子初始化
      
          // 读取输入参数
          cin >> totalPoints >> dimensions >> numNeighbors;
      
          // 读取数据点坐标
          for (int pointIndex = 1; pointIndex <= totalPoints; pointIndex++)
          {
              for (int coord = 0; coord < dimensions; coord++)
                  cin >> pointCoordinates[pointIndex][coord];
          }
      
          // 构建数据点之间的连接
          constructNeighborGraph();
      
          // 读取查询数量
          cin >> numQueries;
          for (int queryIndex = totalPoints + 1; queryIndex <= totalPoints + numQueries; queryIndex++)
          {
              // 读取查询点坐标
              for (int coord = 0; coord < dimensions; coord++)
                  cin >> pointCoordinates[queryIndex][coord];
      
              // 寻找最近邻并输出结果
              findKNearest(queryIndex, numNeighbors);
              for (int resultIndex = 0; resultIndex < numNeighbors; resultIndex++)
                  cout << currentSearchResults[resultIndex].second - 1 << " ";
              cout << endl;
          }
      
          return 0;
      }

  5. 基于哈希的方法:局部敏感哈希LSH
    1. LSH的基本思想是设计一组特殊的哈希函数,使得数据点如果在高维空间中彼此接近,则它们在哈希后的桶中也有较高的概率被放置在一起。换句话说,相似的点更有可能被映射到同一个桶或相邻的桶中,而不相似的点则不太可能被映射到同一个桶中。
    2. 算法设计步骤
      • 哈希函数族:LSH使用一个哈希函数族,其中每个函数都独立地将数据点映射到桶中。常见的LSH函数族包括基于随机投影的哈希(Random Projection Hashing)、MinHash等。
      • 签名:对于每个数据点,应用哈希函数族中的多个哈希函数,产生一个“签名”。签名是数据点的紧凑表示,用于近似比较。
      • 相似度估计:通过比较数据点的签名,可以估计它们在高维空间中的相似度。如果两个数据点的签名在多个哈希函数上一致,那么它们在高维空间中可能非常接近。
      • 桶的创建:将具有相似签名的数据点放入同一个桶中,这些桶可以用作近似最近邻搜索的候选集。
      • 搜索:当查询一个点时,应用相同的哈希函数族得到查询点的签名,然后在具有相似签名的桶中搜索近似最近邻。
    3. 代码
      #include <iostream>
      #include <vector>
      #include <unordered_map>
      #include <unordered_set>
      #include <random>
      #include <cmath>
      #include <cassert>
      #include <algorithm>
      
      using namespace std;
      
      // A vector together with its ID.
      struct Point {
          vector<double> coord; // components of the vector
          int id;               // ID of the point
      };
      
      // One LSH bucket: the IDs of all points that were hashed into it.
      struct LshBucket {
          unordered_set<int> ids; // Set of points' IDs in the bucket
      };
      
      // LSH Parameters
      const int BANDS = 20; // Number of bands in LSH
      const int ROWS_PER_BAND = 10; // Rows per band
      const int TOTAL_ROWS = BANDS * ROWS_PER_BAND; // Total number of hash functions
      
      // Global variables
      vector<Point> points;
      vector<vector<LshBucket>> lshBuckets(BANDS);
      vector<int> hashSeeds(TOTAL_ROWS); // Hash function seeds
      
      // Fills the first TOTAL_ROWS entries of hashSeeds with random integers drawn
      // uniformly from [1, 1000000], using a Mersenne Twister seeded from
      // random_device.
      void initHashSeeds() {
          std::random_device entropy;
          std::mt19937 engine(entropy());
          std::uniform_int_distribution<> seedRange(1, 1000000);
          std::generate_n(hashSeeds.begin(), TOTAL_ROWS,
                          [&seedRange, &engine]() { return seedRange(engine); });
      }
      
      // Hash of a vector for one (seed, row) pair.
      // Computes the weighted sum of the components, where component i (counted
      // from 0) has weight (i + 1) * seed + row, truncates the sum toward zero
      // and returns its remainder modulo 1000000. The remainder carries the sign
      // of the sum, so the result lies in (-1000000, 1000000).
      // Returns 0 when the sum is not finite.
      int minHash(const std::vector<double>& vec, int seed, int row) {
          double value = 0;
          for (std::size_t i = 0; i < vec.size(); ++i) {
              // The weight is formed in double so that it cannot overflow int.
              const double weight = (static_cast<double>(i) + 1.0) * seed + row;
              value += vec[i] * weight;
          }
          if (!std::isfinite(value)) {
              return 0;
          }
          // The remainder is taken before converting to int: the sum itself may
          // lie far outside the int range, and converting such a value directly
          // is undefined behaviour.
          return static_cast<int>(std::fmod(std::trunc(value), 1000000.0));
      }
      
      // Builds the LSH index: for every point and every band, combines the band's
      // ROWS_PER_BAND row hashes into one bucket ID and records the point's ID in
      // that bucket of the band's table.
      void buildLshIndex() {
          for (const auto& point : points) {
              for (int band = 0; band < BANDS; ++band) {
                  // The row hashes are folded together as base-1000000 digits.
                  // The fold exceeds the int range after two rows, so it is done
                  // in unsigned arithmetic, where wrap-around is well defined.
                  // Queries must fold their hashes the same way to land in the
                  // same buckets.
                  unsigned int combined = 0;
                  for (int row = 0; row < ROWS_PER_BAND; ++row) {
                      const int rowHash = minHash(point.coord, hashSeeds[row + band * ROWS_PER_BAND], row);
                      combined = combined * 1000000u + static_cast<unsigned int>(rowHash);
                  }
                  const int bucketId = static_cast<int>(combined);
                  lshBuckets[band][bucketId].ids.insert(point.id);
              }
          }
      }
      
      // Returns the Euclidean distance between the coordinate vectors of a and b.
      // If the two points differ in dimension, only the common leading
      // components are compared so that neither vector is read out of bounds.
      double euclideanDistance(const Point& a, const Point& b) {
          const std::size_t dims = std::min(a.coord.size(), b.coord.size());
          double dist = 0;
          for (std::size_t i = 0; i < dims; ++i) {
              // Plain multiplication instead of the general-purpose pow().
              const double diff = a.coord[i] - b.coord[i];
              dist += diff * diff;
          }
          return std::sqrt(dist);
      }
      
      // K-NN search using LSH
      vector<int> lshKnnSearch(const Point& query, int k) {
          vector<int> candidates;
          for (int band = 0; band < BANDS; ++band) {
              int bucketId = 0;
              for (int row = 0; row < ROWS_PER_BAND; ++row) {
                  bucketId *= 1000000;
                  bucketId += minHash(query.coord, hashSeeds[row + band * ROWS_PER_BAND], row);
              }
              const auto& bucket = lshBuckets[band][bucketId];
              for (const auto& id : bucket.ids) {
                  candidates.push_back(id);
              }
          }
          
          // Remove duplicates and sort by distance to query
          unordered_map<int, double> distances;
          for (const auto& id : candidates) {
              distances[id] = euclideanDistance(points[id], query);
          }
          vector<int> sortedIds;
          for (const auto& pair : distances) {
              sortedIds.push_back(pair.first);
          }
          sort(sortedIds.begin(), sortedIds.end(), [&distances](int a, int b) {
              return distances[a] < distances[b];
          });
          
          // Return top-k results
          vector<int> result;
          for (int i = 0; i < k && i < sortedIds.size(); ++i) {
              result.push_back(sortedIds[i]);
          }
          return result;
      }
      
      int main() {
          // Read input
          int n, d, k;
          cin >> n >> d >> k;
          points.resize(n);
          for (int i = 0; i < n; ++i) {
              points[i].coord.resize(d);
              for (int j = 0; j < d; ++j) {
                  cin >> points[i].coord[j];
              }
              points[i].id = i;
          }
      
          // Initialize hash seeds and build LSH index
          initHashSeeds();
          buildLshIndex();
      
          // Process queries
          int q;
          cin >> q;
          for (int qi = 0; qi < q; ++qi) {
              Point query(d);
              for (int j = 0; j < d; ++j) {
                  cin >> query.coord[j];
              }
              vector<int> nearestNeighbors = lshKnnSearch(query, k);
              for (int nn : nearestNeighbors) {
                  cout << nn << " ";
              }
              cout << endl;
          }
      
          return 0;
      }

  • 18
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值