# -*- coding: utf-8 -*-
"""
案例:水果识别
任务:通过knn算法对水果进行识别
"""
import os

import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
# 1. Load the data.
# pd.read_table uses a tab separator by default, so the file is presumably
# tab-delimited — confirm against the data file.
fruits_df = pd.read_table('./data/fruit_data_with_colors.txt')
# DataFrame.info() writes its summary to stdout itself and returns None;
# wrapping it in print() would emit an extra "None" line.
fruits_df.info()
print()
print(fruits_df.describe())
print()
print(fruits_df.head())
# 2. Build a lookup from numeric target label to fruit name
#    (later rows win if a label appears more than once).
fruit_name_dict = {
    label: name
    for label, name in zip(fruits_df['fruit_label'], fruits_df['fruit_name'])
}
# 3. Select the four numeric feature columns and the target, then hold out
#    a quarter of the rows as the test set (fixed seed for reproducibility).
X = fruits_df[['mass', 'width', 'height', 'color_score']]
y = fruits_df['fruit_label']
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=0,
)
print(
    f"\n数据集样本数:{len(X)},"
    f"训练集样本数:{len(X_train)},"
    f"测试集样本数:{len(X_test)}"
)
# 4./5. Choose a k-nearest-neighbours classifier with k = 5 and train it.
#       fit() returns the estimator itself, so the two steps can be chained.
knn = KNeighborsClassifier(n_neighbors=5).fit(X_train, y_train)
# 6. Evaluate the trained model on the held-out test set.
y_pred = knn.predict(X_test)
acc = accuracy_score(y_test, y_pred)
print("\nk为5时准确率为:", acc)
# 7. Examine how the number of neighbours k affects test-set accuracy.
k_range = range(1, 20)
acc_scores = []
for k in k_range:
    # Retrain a fresh classifier for each k; score() on a classifier
    # returns the mean accuracy on the given data.
    knn = KNeighborsClassifier(n_neighbors=k)
    knn.fit(X_train, y_train)
    acc_scores.append(knn.score(X_test, y_test))
plt.figure()
plt.xlabel('k')
plt.ylabel('accuracy')
plt.plot(k_range, acc_scores, marker='o')
# Evenly spaced ticks spanning the plotted k values (1..19).
plt.xticks([0, 5, 10, 15, 20])
# savefig does not create missing directories, so make sure it exists.
os.makedirs('./output', exist_ok=True)
plt.savefig('./output/k_validation.png')
plt.show()