概念:
k 近邻法 (k-nearest neighbor,k-NN) 是一种基本分类与回归方法。 主要思想:假定给一个训练数据集,其中实例标签已定,当输入新的实例时,可根据其最近的k个训练实例标签,预测新实例对应的标注信息。
应用领域
- 分类:根据周围已知类别的样本来预测新样本的类别。
- 回归:预测连续值(如房价、温度等)。
当k=3,绿色圆点属于红色三角形类别,当k=5,绿色圆点属于蓝色正方形类别。
算法
数学表示
-
数据集:假设我们有一个训练数据集 D={(x1,y1),(x2,y2),...,(xn,yn)},其中每个 ��xi 是一个特征向量,yi 是对应的标签。
-
距离度量:选择一个距离度量函数 d(x,x′) 来计算两个点之间的距离。常见的距离度量包括:
- 欧几里得距离 (Euclidean distance):2d(x,x′)=∑i=1m(xi−xi′)2,其中m 是特征的数量。
- 曼哈顿距离 (Manhattan distance):∣d(x,x′)=∑i=1m∣xi−xi′∣。
-
查找最近的K个邻居:对于一个新的未标记点 x,找到训练集 D 中距离 x 最近的K个点。这可以表示为:Nk(x)={(x′,y′)∈D∣x′ 是离 x 第k近的点}。
-
分类或回归:
- 分类:y=mode{y′∣(x′,y′)∈Nk(x)},即选择 Nk(x) 中最常见的类别标签作为预测结果。
- 回归:y=k1∑(x′,y′)∈Nk(x)y′,即取 Nk(x) 中所有点的标签的平均值作为预测结果。
参数选择
- K值的选择:K的选择会影响算法的性能。较小的K值意味着模型对噪声更敏感,而较大的K值则可能使模型无法捕捉到数据的局部特征。
Kd树
K最近邻(KNN)算法中。KD树是一种二叉树,每个节点代表K维空间中的一个数据点,使得搜索特定数据点变得更加高效。
KD树的构建过程是递归的,其核心步骤如下:
-
选择轴:在每个节点,选择一个维度作为“切分轴”。通常,这个选择是轮流进行的,例如,在一个二维空间中,可能依次选择X轴、Y轴。
-
分割数据:根据切分轴上的中位数将数据分割成两部分,分别构成当前节点的左子树和右子树。切分的目标是平衡树结构,以优化搜索效率。
-
递归构建:对左右子树重复上述过程,直到每个子区域只有一个数据点。
搜索过程
在KD树中搜索最近邻的过程涉及以下步骤:
-
向下搜索树:从根节点开始,向下遍历树,直到找到最近邻候选点。
-
回溯检查:然后回溯到父节点,检查另一侧的子树是否有更近的点。如果有,更新最近邻候选。
-
最近邻确定:继续回溯和检查,直到根节点。此时的最近邻候选即为最终的最近邻。
import numpy as np
class Node:
"""
定义KD树的节点
point: 节点包含的数据点
left, right: 节点的左右子节点
"""
def __init__(self, point, left=None, right=None):
self.point = point
self.left = left
self.right = right
def build_kdtree(points, depth=0):
"""
构建KD树
points: 数据点的列表
depth: 当前深度(用于选择轴)
"""
n = len(points)
if n == 0:
return None
k = len(points[0]) # 假设所有点的维度相同
axis = depth % k # 选择轴
# 按照当前轴的坐标对点进行排序
sorted_points = sorted(points, key=lambda point: point[axis])
median_index = n // 2 # 找到中位数的索引
# 递归创建左右子树
return Node(
point=sorted_points[median_index],
left=build_kdtree(sorted_points[:median_index], depth + 1),
right=build_kdtree(sorted_points[median_index + 1:], depth + 1)
)
def closest_point(root, point, depth=0):
"""
在KD树中查找最近的点
root: KD树的根节点
point: 查询的点
depth: 当前深度
"""
if root is None:
return None
k = len(point)
axis = depth % k
next_branch = None
opposite_branch = None
# 根据当前轴的坐标选择下一分支
if point[axis] < root.point[axis]:
next_branch = root.left
opposite_branch = root.right
else:
next_branch = root.right
opposite_branch = root.left
# 检查当前分支中的最近点
best = closer_point(point,
closest_point(next_branch, point, depth + 1),
root.point)
# 检查是否需要在对面分支中搜索
if distance(point, best) > abs(point[axis] - root.point[axis]):
best = closer_point(point,
closest_point(opposite_branch, point, depth + 1),
best)
return best
def distance(point1, point2):
"""
计算两点之间的欧几里得距离
"""
return np.sqrt(sum((p1 - p2) ** 2 for p1, p2 in zip(point1, point2)))
def closer_point(reference, point1, point2):
"""
比较两点哪个更接近参考点
"""
if point1 is None:
return point2
if point2 is None:
return point1
if distance(reference, point1) < distance(reference, point2):
return point1
else:
return point2
# 示例使用
points = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]
tree = build_kdtree(points)
point = [3, 4.5]
print("离点", point, "最近的点是", closest_point(tree, point))
离点 [3, 4.5] 最近的点是 [2, 3]
整体代码为
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
# 定义KD树节点类和相关函数
class Node:
def __init__(self, point, left=None, right=None):
self.point = point
self.left = left
self.right = right
def build_kdtree(points, depth=0):
if not points:
return None
k = len(points[0])
axis = depth % k
points.sort(key=lambda x: x[axis])
median = len(points) // 2
return Node(
point=points[median],
left=build_kdtree(points[:median], depth + 1),
right=build_kdtree(points[median + 1:], depth + 1)
)
def distance_squared(point1, point2):
return sum((p1 - p2) ** 2 for p1, p2 in zip(point1, point2))
def k_nearest_neighbors(root, point, k, depth=0, verbose=False):
if root is None:
return []
k_dimensions = len(point)
axis = depth % k_dimensions
next_branch = root.left if point[axis] < root.point[axis] else root.right
opposite_branch = root.right if point[axis] < root.point[axis] else root.left
if verbose:
print(f"Depth: {depth}, Axis: {axis}, Checking point: {root.point}")
best = k_nearest_neighbors(next_branch, point, k, depth + 1, verbose)
if len(best) < k or distance_squared(point, root.point) < distance_squared(point, best[-1][0]):
best.append((root.point, distance_squared(point, root.point)))
best.sort(key=lambda x: x[1])
if len(best) < k or abs(point[axis] - root.point[axis]) < best[-1][1]:
best.extend(k_nearest_neighbors(opposite_branch, point, k, depth + 1, verbose))
best.sort(key=lambda x: x[1])
return best[:k]
def get_neighbor_labels(neighbors, X, y):
labels = []
for neighbor in neighbors:
index = np.where(np.all(X == neighbor[0], axis=1))[0][0]
labels.append(y[index])
return labels
# 加载Iris数据集
iris = datasets.load_iris()
X = iris.data[:, :2] # 只取前两个特征
y = iris.target
# 构建KD树
kd_tree = build_kdtree(X.tolist())
# 输入点的坐标
input_point = [5.6, 6.3]
# 使用KD树查找输入点的最近邻
k = 5 # 查找最近的5个邻居
neighbors = k_nearest_neighbors(kd_tree, input_point, k, verbose=True)
# 获取最近邻的类别
neighbor_labels = get_neighbor_labels(neighbors, X, y)
# 预测类别:选择最常见的类别
predicted_label = max(set(neighbor_labels), key=neighbor_labels.count)
predicted_class = iris.target_names[predicted_label]
print(f"Input point {input_point} is predicted to be '{predicted_class}'")
# 绘制结果
plt.figure(figsize=(10, 6))
plt.scatter(X[:, 0], X[:, 1], color='lightgray', label='Data points')
plt.scatter([n[0][0] for n in neighbors], [n[0][1] for n in neighbors], color='blue', label='K Nearest Neighbors')
plt.scatter(input_point[0], input_point[1], color='red', label='Input Point')
plt.xlabel('Sepal Length')
plt.ylabel('Sepal Width')
plt.title(f'KD Tree KNN Search on Iris Dataset (Predicted: {predicted_class})')
plt.legend()
plt.show()