from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import GridSearchCV
def iris_KNN(random_state=69, cv=10):
    """Train a KNN classifier on the Iris dataset and print test-set results.

    The data is split into training and test sets. ``n_neighbors`` is chosen
    by a grid search with ``cv``-fold cross-validation on the training set.
    The best model found is then evaluated on the held-out test set.

    Standardization is done inside a ``Pipeline`` together with the KNN
    step. That way, during cross-validation, the scaler is fitted only on
    each training fold. Fitting the scaler on the whole training set before
    the search would leak the validation folds' mean/variance into model
    selection.

    Args:
        random_state: Seed passed to ``train_test_split`` (default 69).
        cv: Number of cross-validation folds used by ``GridSearchCV``
            (default 10).

    Returns:
        None. Results are printed to stdout.
    """
    # Local import: keeps this block self-contained without touching the
    # file-level import list.
    from sklearn.pipeline import Pipeline

    # 1. Load the data.
    iris = load_iris()
    # 2. Split into training and test sets.
    x_train, x_test, y_train, y_test = train_test_split(
        iris.data, iris.target, random_state=random_state
    )
    # 3 + 4. Standardize the features, then classify with KNN. Chaining both
    # steps lets GridSearchCV refit the scaler on every CV training fold.
    pipeline = Pipeline([
        ("scaler", StandardScaler()),
        ("knn", KNeighborsClassifier(n_neighbors=5)),
    ])
    # Grid search with cross-validation over the number of neighbours.
    # Pipeline step parameters are addressed as "<step name>__<param>".
    param_dict = {"knn__n_neighbors": [1, 3, 5, 7, 9]}
    estimator = GridSearchCV(pipeline, param_grid=param_dict, cv=cv)
    # Train: run the search; the best pipeline is then refitted on x_train.
    estimator.fit(x_train, y_train)
    # 5. Evaluate the model.
    # Method 1: compare the predictions element-wise with the true labels.
    y_predict = estimator.predict(x_test)
    print("预估\n", y_predict == y_test)
    # Method 2: score the model on the test set.
    score = estimator.score(x_test, y_test)
    print("score:\n", score)
    return None
# Run the example only when this file is executed as a script, not on import.
if __name__ == "__main__":
    iris_KNN()
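# 鸢尾花数据集的KNN算法模型的训练和预估
# (Training and prediction of a KNN model on the Iris dataset.)
# 于 2022-03-05 19:56:42 首次发布 (first published 2022-03-05 19:56:42)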