实现了KD平衡树的程序,由于MATLAB实现需要用到引用类型或者采用循环实现(见MATLAB的KDTreeSearcher.m),因此采用C#来实现
using System;
using System.Collections.Generic;
using System.Linq;
namespace KNNSearch
{
///
/// Description of KNN.
///
public class KNN
{
///
/// 叶子节点点的个数
///
private int leafnum = 30;
///
/// 待分类数据
///
private List
rawData;
///
/// 生成原始数据
///
private void GeneralRawData()
{
if (rawData == null)
{
rawData = new List
();
}
else
{
rawData.Clear();
}
Random r = new Random();
for (int i = 0; i < 500; i++)
{
rawData.Add(new Point() { X = r.NextDouble(), Y = r.NextDouble(), Z = r.NextDouble() });
}
}
///
/// 创建KD树
///
///
///
private Node CreateKDTree(List
data)
{
// 创建根节点
Node root = new Node();
// 添加当前节点数据
root.nodeData = data;
// 如果节点的数据数量小于叶子节点的数量限制,则当前节点为叶子节点
if (data.Count <= leafnum)
{
root.leftNode = null;
root.rightNode = null;
root.point = double.NaN;
root.splitaxis = -1;
return root;
}
// 找到分割轴
int splitAxis = GetSplitAxis(data);
// 分割数据
Tuple
, List
> dataSplit = GetSplitNum(data, splitAxis); root.splitaxis = splitAxis; root.point = dataSplit.Item1; root.leftNode = CreateKDTree(dataSplit.Item2); root.rightNode = CreateKDTree(dataSplit.Item3); return root; } private Tuple
, List
> GetSplitNum(List
data, int splitAxis) { // 对数据按照第splitAxis排序 var data0 = splitAxis == 0 ? (data.OrderBy(x => x.X)).ToList() : (splitAxis == 1 ? (data.OrderBy(x => x.Y)).ToList() : (data.OrderBy(x => x.Z)).ToList()); int half = data0.Count / 2; List
leftdata = new List
(); List
rightdata = new List
(); for (int i = 0; i < data0.Count; i++) { if (i <= half) { leftdata.Add(data0[i]); } else { rightdata.Add(data0[i]); } } double splitnum = splitAxis == 0 ? data[half].X + data[half + 1].X : (splitAxis == 1 ? data[half].Y + data[half + 1].Y : data[half].Z + data[half + 1].Z); return new Tuple
, List
>(splitnum / 2, leftdata, rightdata); } ///
/// 获取分割轴编号 /// ///
///
private int GetSplitAxis(List
data) { // 设定数据范围最大的轴作为分割轴(也有其他的方式,如方差,或者轮流的方式) var xData = data.Select(item => item.X); var yData = data.Select(item => item.Y); var zData = data.Select(item => item.Z); List
ranges = new List
(); ranges.Add(xData.Max() - xData.Min()); ranges.Add(yData.Max() - yData.Min()); ranges.Add(zData.Max() - zData.Min()); var sorted = ranges.Select((x, i) => new KeyValuePair
(x, i)).OrderByDescending(x => x.Key).ToList(); return sorted.Select(x => x.Value).ToList()[0]; ; } public KNN() { GeneralRawData(); Node node = CreateKDTree(rawData); } } ///
/// Description of Node. /// public class Node { ///
/// 切分的阈值点 /// public double point; ///
/// 左节点 /// public Node leftNode; ///
/// 右节点 /// public Node rightNode; ///
/// 节点包含的数据 /// public List
nodeData; ///
/// 分割轴 /// public int splitaxis; public Node() { } } public class Point { public double X; public double Y; public double Z; } }