1 项目描述
使用逻辑回归算法来对鸢尾花进行分类;
数据集包括训练数据train.txt和测试数据test.txt;测试数据中,每个样本包括特定的几个特征参数,最后是一个类别标签,而测试数据中的样本则只包括了特征参数
2 逻辑回归:鸢尾花数据集分类
2.1 鸢尾花数据信息
-
Sepal length: 花萼长度
-
Sepal width: 花萼宽度
-
Petal length: 花瓣长度
-
Petal width: 花瓣宽度
2.2 鸢尾花分类
2.3 问题描述
如果: 花萼长度,花萼宽度, 花瓣长度,花瓣宽度为5.1, 3.5, 1.4, 0.2
问:是什么花
3 分析问题
3.1 加载数据集
def load_data():
"""
加载数据集
:return:
X: 花瓣宽度
Y: 鸢尾花类型
"""
# 加载sklearn包自带的鸢尾花数据;
iris = datasets.load_iris()
# # 查看鸢尾花的数据集
# print(iris)
# # 查看鸢尾花的key值;
# # dict_keys(['data', 'target', 'target_names', 'DESCR','feature_names', 'filename'])
# print(iris.keys())
# # 获取鸢尾花的特性: ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
# print(iris['feature_names'])
# print(iris['data'])
# print(iris['target'])
# 因为花瓣的相关系数比较高, 所以分类效果比较好, 所以我们就用花瓣宽度当作x;
X = iris['data'][:, 3:]
# 获取分类的结果
Y = iris['target']
return X, Y
3.2 可视化展示
图形配置
def configure_plt(plt):
"""
配置图形的坐标表信息
"""
# 获取当前的坐标轴, gca = get current axis
ax = plt.gca()
# 设置x轴, y周在(0, 0)的位置
ax.spines['bottom'].set_position(('data', 0))
ax.spines['left'].set_position(('data', 0))
# 绘制x,y轴说明
plt.xlabel('petal width (cm)') # 花瓣宽度
plt.ylabel('target') # 鸢尾花类型
return plt
- 绘图
def