自己参考Bobo老师写得代码:
主要分为四个文件: knn.py中实现KNN算法、model_selection.py封装了样本数据的一些工具方法,比如切分为训练集和测试集;
metrics用来对模型进行评估、client用来调用算法进行运行
# -*- encoding: utf-8 -*-
"""
实现KNN的分类算法
"""
import numpy as np
from math import sqrt
from collections import Counter
from metrics import accuracy_score
class KnnClassifier(object):
"""
K-近邻算法,(K Nearest Neighbour),简称KNN
"""
def __init__(self, k):
"""
K表示
:param k: 表示参考的个数
"""
self.k = k
def fit(self, X_train, y_train):
"""
利用输入的样本集进行训练KNN算法
:param X_train: X 训练样本集
:param y_train: y
:return:
"""
self.X_train = X_