# Simple version (简单来写):
def fit(train, k):
    """Remember the training data and the neighbor count for later prediction.

    Pseudo-code: ``self`` stands for the model object and is not defined in
    this snippet. KNN has no real training step here; the inputs are only
    stored so that ``predict`` can read them back.

    Args:
        train: The training samples.
        k: How many nearest neighbors ``predict`` should look up.
    """
    self.train, self.k = train, k
def predict(test):
    """Predict the label of a single query sample (simple KNN version).

    Pseudo-code: ``self``, ``fetch_k_neighbors`` and ``calc_predict_label``
    are assumed to be supplied by the surrounding context; none of them is
    defined in this snippet.

    Args:
        test: The query sample to classify.

    Returns:
        Whatever ``calc_predict_label`` produces for the k nearest neighbors.
    """
    # a. Fetch the k training samples closest to the query sample.
    nearest = fetch_k_neighbors(self.train, test, self.k)
    # b. Merge those k nearest samples into the predicted value.
    return calc_predict_label(nearest)
# Detailed version (复杂来写):
def fit(train, k):
    """Store the training set and ``k`` on the model; nothing is computed.

    Pseudo-code: ``self`` denotes the model object and is not defined in this
    snippet. The stored values are read back by ``predict``.

    Args:
        train: The training samples.
        k: Number of nearest neighbors to consult at prediction time.
    """
    self.train, self.k = train, k
def predict(test):
    """Predict a label for every sample in ``test`` by majority vote (KNN).

    Pseudo-code: ``self`` (the model holding ``train`` and ``k``) and
    ``fetch_k_neighbors`` are assumed to be supplied by the surrounding
    context; neither is defined in this snippet.

    Args:
        test: Iterable of query samples.

    Returns:
        A list with one predicted label per query sample, in input order.
        An entry is ``None`` when no neighbors were returned for that sample.
    """
    # Local import keeps the snippet self-contained (the file has no import block).
    from collections import Counter

    result = []
    for x in test:
        # a. Fetch the k training samples closest to the current sample x.
        neighbors = fetch_k_neighbors(self.train, x, self.k)
        # b. Merge the k nearest samples into a single prediction.
        # b1. Count how many times each label occurs among the neighbors.
        label_counts = Counter(neighbor.label for neighbor in neighbors)
        # b2. Pick the most frequent label. On a tie the label seen first
        #     wins, matching a strict ``>`` scan over an insertion-ordered dict.
        top = label_counts.most_common(1)
        # b3. Record the prediction; ``top`` is empty when there are no neighbors.
        result.append(top[0][0] if top else None)
    return result
KNN伪代码(简易版和复杂版)
最新推荐文章于 2023-07-03 21:23:49 发布