sklearn KNN算法实现鸢尾花分类
- 编译环境
- python 3.6
- 使用到的库
- sklearn
简介
本文利用sklearn中自带的数据集(鸢尾花数据集),并通过KNN算法实现了对鸢尾花的分类。
KNN算法核心思想:如果一个样本在特征空间中的K个最相似(最近临)的样本中大多数属于某个类别,则该样本也属于这个类别。
sklearn库介绍
自2007年发布以来,scikit-learn已经成为最给力的Python机器学习库(library)了。scikit-learn支持的机器学习算法包括分类,回归,降维和聚类。还有一些特征提取(extracting features)、数据处理(processing data)和模型评估(evaluating models)的模块。
安装:
pip install sklearn
鸢尾花数据集介绍
sklearn.datasets.load_iris() # 加载并返回鸢尾花数据集
</tr>
<tr>
<td>特征</td>
<td>4</td>
</tr>
<tr>
<td>样本数量</td>
<td>150</td>
</tr>
<tr>
<td>每个类别数量</td>
<td>50</td>
</tr>
</tbody>
名称 | 数量 |
---|---|
类别 | 3 |
sklearn.neighbors.KNeighborsClassifier(n_neighbors=5,algorithm='auto')
- n_neighbors:int,可选(默认= 5),k_neighbors查询默认使用的邻居数
- algorithm:{‘auto’,‘ball_tree’,‘kd_tree’,‘brute’},可选用于计算最近邻居的算法:‘ball_tree’将会使用 BallTree,‘kd_tree’将使用 KDTree。‘auto’将尝试根据传递给fit方法的值来决定最合适的算法。 (不同实现方式影响效率)
获取鸢尾花数据
from sklearn.datasets import load_iris
def get_iris_data(self):
iris = load_iris()
iris_data = iris.data # 鸢尾花特征值(4个)
iris_target = iris.target # 鸢尾花目标值(类别)
return iris_data, iris_target