【机器学习实战】逻辑回归----digits手写数字分类

【导入库和数据集】

和线性回归一样,首先导入所需要用到的库和数据集。
导入库:

##用于可视化图表
import matplotlib.pyplot as plt
##用于做科学计算
import numpy as np
##用于做数据分析
import pandas as pd
##用于加载数据或生成数据等
from sklearn import datasets
##加载线性模型
from sklearn import linear_model
###用于交叉验证以及训练集和测试集的划分
from sklearn.cross_validation import train_test_split
from sklearn.model_selection import cross_val_predict
#from sklearn.cross_validation import cross_val_score
###这个模块中含有评分函数,性能度量,距离计算等
from sklearn import metrics
###用于做数据预处理
from sklearn import preprocessing

数据集:
这次数据集还是选用sklearn中提供的小数据集——digits手写字体数据集。先看一下官方介绍:
这里写图片描述
通过官方的数据集介绍我们可以知道,这个digits手写数据集其实是1797组8*8的手写数字图像的像素点集合。有10个分类,代表了“0,1,2,…,9”这是个数字。特征维度为64,对应了每组数据的8*8个像素点。大概知道了这些,我们可以具体查看一下。

digits = datasets.load_digits()#导入digits数据集
print(digits.keys())#查看digits中有哪些属性
输出为:dict_keys(['images', 'target_names', 'DESCR', 'data', 'target'])
(1797, 64)
[0 1 2 3 4 5 6 7 8 9]
0
[[  0.   0.   5.  13.   9.   1.   0.   0.]
 [  0.   0.  13.  15.  10.  15.   5.   0.]
 [  0.   3.  15.   2.   0.  11.   8.   0.]
 [  0.   4.  12.   0.   0.   8.   8.   0.]
 [  0.   5.   8.   0.   0.   9.   8.   0.]
 [  0.   4.  11.   0.   1.  12.   7.   0.]
 [  0.   2.  14.   5.  10.  12.   0.   0.]
 [  0.   0.   6.  13.  10.   0.   0.   0.]]
1
[[  0.   0.   0.  12.  13.   5.   0.   0.]
 [  0.   0.   0.  11.  16.   9.   0.   0.]
 [  0.   0.   3.  15.  16.   6.   0.   0.]
 [  0.   7.  15.  16.  16.   2.   0.   0.]
 [  0.   0.   1.  16.  16.   3.   0.   0.]
 [  0.   0.   1.  16.  16.   6.   0.   0.]
 [  0.   0.   1.  16.  16.   6.   0.   0.]
 [  0.   0.   0.  11.  16.  10.   0.   0.]]

digits数据集中有1797个数据,分类标签为0~9,打印第一和第二个图像的标签和数据看看,可以发现每一个像素点在0~16之间。不过这个看起来感觉不够直观,可以画图来看。

plt.gray()
for i in range(0,2):
    plt.matshow(digits.images[i])
    plt.show()
    print(digits.target[i])

输出为:
这里写图片描述

大概能看出数字的样子。我们可以将这些图像和它的标签打印到一起来看。

fig=plt.figure(figsize=(8,8))
fig.subplots_adjust(left=0,right=1,bottom=0,top=1,hspace=0.05,wspace=0.05)
for i in range(30):
    ax=fig.add_subplot(6,5,i+1,xticks=[],yticks=[])
    ax.imshow(digits.images[i],cmap=plt.cm.binary,interpolation='nearest')
    ax.text(0,7,str(digits.target[i]))
plt.show()

输出为:
这里写图片描述

【二分类问题】

从上面对数据集的分析,我们知道digits数据集有10个分类,我们将0~4当做一类,5~9当做另一类,问题就变为了二分类问题。
导入数据:
获得数据集的输入和输出,并且划分训练集和测试集。

digits_X = digits.data   ##获得数据集中的输入
digits_y = digits.target ##获得数据集中的输出,即标签(也就是类别)
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值