Tensorflow2计算二维点集的K近邻

# -*- coding: utf-8 -*-
"""
Created on 2021.06.17
@author: xi'an Li
"""
import tensorflow as tf
import numpy as np


def pairwise_distance(point_set):
    """Compute pairwise distance of a point cloud.
        Args:
          (x-y)^2 = x^2 - 2xy + y^2
          point_set: tensor (num_points, dims2point)
        Returns:
          pairwise distance: (num_points, num_points)
    """
    point_set_shape = point_set.get_shape().as_list()
    assert(len(point_set_shape)) == 2

    point_set_transpose = tf.transpose(point_set, perm=[1, 0])
    point_set_inner = tf.matmul(point_set, point_set_transpose)
    point_set_inner = -2 * point_set_inner
    point_set_square = tf.reduce_sum(tf.square(point_set), axis=-1, keepdims=True)
    point_set_square_transpose = tf.transpose(point_set_square, perm=[1, 0])
    return point_set_square + point_set_inner + point_set_square_transpose


def knn_includeself(dist_matrix, k=20):
    """Get KNN based on the pairwise distance.
        How to use tf.nn.top_k(): https://blog.csdn.net/wuguangbin1230/article/details/72820627
      Args:
        pairwise distance: (num_points, num_points)
        k: int

      Returns:
        nearest neighbors: (num_points, k)
      """
    neg_dist = -1.0*dist_matrix
    _, nn_idx = tf.nn.top_k(neg_dist, k=k)  # 这个函数的作用是返回 input 中每行最大的 k 个数,并且返回它们所在位置的索引
    return nn_idx


def knn_excludeself(dist_matrix, k=20):
    """Get KNN based on the pairwise distance.
      Args:
        pairwise distance: (num_points, num_points)
        k: int

      Returns:
        nearest neighbors index: (num_points, k)
      """
    neg_dist = -1.0*dist_matrix
    k_neighbors = k+1
    _, knn_idx = tf.nn.top_k(neg_dist, k=k_neighbors)  # 这个函数的作用是返回 input 中每行最大的 k 个数,并且返回它们所在位置的索引
    nn_idx = knn_idx[:, 1: k_neighbors]
    return nn_idx

def get_kneighbors_2DTensor(point_set, nn_idx):
    """Construct neighbors feature for each point
        Args:
        point_set: (num_points, dim)
        nn_idx: (num_points, k_num)
        num_points: the number of point
        k_num: the number of neighbor

        Returns:
        neighbors features: (num_points, k_num, dim)
      """
    shape2point_set = point_set.get_shape().as_list()
    assert(len(shape2point_set) == 2)
    point_set_neighbors = tf.gather(point_set, nn_idx)
    return point_set_neighbors

if __name__ == "__main__":
    num2points = 50
    kneighbor = 3
    indim = 2
    outdim = 1
    hiddens = (4, 8, 16)
    # x0 = np.array([[1, 2],
    #               [2, 3],
    #               [3, 5],
    #               [8, 9],
    #               [5, 2],
    #               [4, 1],
    #               [3, 2],
    #               [9, 3],
    #               [8, 5],
    #               [7, 1]], dtype=np.float)
    x0 = np.array([[1, 2],
                   [2, 3],
                   [3, 5],
                   [8, 9],
                   [5, 2]], dtype=np.float)
    print(x0)
    # x_point = np.random.rand(num2points, indim)
    # X = tf.Variable(initial_value=x_point, trainable=True, name='X')
    X = tf.Variable(initial_value=x0, trainable=True, name='X')
    print(X)
    adj_mat = pairwise_distance(X)
    print('adj_mat:', adj_mat)
    k_index = knn_excludeself(adj_mat, k=kneighbor)
    print('k_index:', k_index)
    neighbors = get_kneighbors_2DTensor(X, k_index)
    print('neighbors[0]', neighbors[0, :, :])

结果:
<tf.Variable ‘X:0’ shape=(5, 2) dtype=float64, numpy=
array([[1., 2.],
[2., 3.],
[3., 5.],
[8., 9.],
[5., 2.]])>
adj_mat: tf.Tensor(
[[ 0. 2. 13. 98. 16.]
[ 2. 0. 5. 72. 10.]
[13. 5. 0. 41. 13.]
[98. 72. 41. 0. 58.]
[16. 10. 13. 58. 0.]], shape=(5, 5), dtype=float64)
k_index: tf.Tensor(
[[1 2 4]
[0 2 4]
[1 0 4]
[2 4 1]
[1 2 0]], shape=(5, 3), dtype=int32)
neighbors[0] tf.Tensor(
[[2. 3.]
[3. 5.]
[5. 2.]], shape=(3, 2), dtype=float64)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值