各种机器学习方法(线性回归、支持向量机、决策树、朴素贝叶斯、KNN算法、逻辑回归)实现手写数字识别并用准确率、召回率、F1进行评估

本文转自:http://blog.csdn.net/net_wolf_007/article/details/51794254

前面两章对数据进行了简单的特征提取及线性回归分析。识别率已经达到了85%, 完成了数字识别的第一步:数据探测。
这一章要做的就各种常用机器学习算法来对数据进行测试,并总结规律,为后续进一步提供准确率作准备。
这单选取的算法有:(后面有时间再对每个算法单独作分析总结介绍):
  1. 线性回归
  2. 支持向量机
  3. 决策树
  4. 朴素贝叶斯
  5. KNN算法
  6. 逻辑回归
以测试样本的最后一千个数作为测试样例,其它的作为训练样例。 数据结果为测试样例的识别结果。

使用到的统计概念有:precision,recall及f1-score(参见文章 机器学习结果统计-准确率、召回率,F1-score
先看测试总结数据,后面会提供完整代码及数据提供。使用的机器学习算法库为: http://scikit-learn.org/

算法测试数据
对算法进行测试,测试数据如下:

测试算法中文precisionrecallf1-score样本数所花时间(秒)
LinearRegression线性回归0.840.840.8410002.402
SVC支持向量机0.850.850.85100022.72
DecisionTree决策树0.810.810.8110000.402
Bayes朴素贝叶斯0.780.770.7710000.015
KNNKNN算法0.860.860.8610000.374
LogisticRegression逻辑回归0.820.820.8210002.419

总结:
算法识别结果想差不多,除了贝叶斯之外都在80%以上。 
时间花销上SVC算法花得最多。而贝叶斯花得最小。 
在手写识别上,如果样本量够多,相比之下KNN算法会更确。

分析:
由于贝叶期是基于概率统计的算法。 计算时间会很快。而由于缺乏对特征之间的关心的支持,识别率会弱一些。

各数字识别统计
 线性回归支持向量机决策树基本贝叶斯算法KNN算法逻辑回归  
数字recallprecisionrecallprecisionrecallprecisionrecallprecisionrecallprecisionrecallprecisionrecall 平均precision平均
00.930.830.930.880.920.90.920.880.950.870.950.890.9330.875
10.960.910.970.970.960.960.970.840.990.980.940.940.9650.933
20.820.820.840.760.80.740.590.730.810.810.730.760.7650.770
30.790.720.750.740.710.740.710.610.810.710.730.670.7500.698
40.680.870.730.840.70.770.640.830.750.850.680.780.6970.823
50.730.780.750.80.660.730.670.630.780.840.80.80.7320.763
60.930.930.920.880.890.880.910.840.920.910.890.910.9100.892
70.890.830.880.850.780.720.740.780.870.880.860.820.8370.813
80.830.830.830.850.780.790.820.770.810.90.830.80.8170.823
90.820.860.820.850.860.830.720.80.840.810.790.80.8080.825

数据总结:
  1. 整体数据中,1,6的识别率最高,而2, 3, 4, 5识别率最低。说明数据中比较难区分2, 3, 4, 5的属性。
  2. 注意线性回归中,0的recall 大于precision10个百分点,意味着把别的数据错误识别成0成的较多。
  3. 4中 recall小于precision 20%面分则说明识别4的特征不明显。



终合上面分析结果得到后续计划:
  1. 继续分析数据,寻找更有用的特征值。
  2. 使用KNN来作为数据分析算法。


测试代码
[python]  view plain  copy
  1. import numpy as np  
  2. from sklearn.naive_bayes import GaussianNB  
  3. from sklearn.neighbors import KNeighborsClassifier  
  4. from sklearn.tree import DecisionTreeClassifier  
  5. from sklearn.linear_model import LogisticRegression  
  6. from sklearn.linear_model import LinearRegression  
  7. from sklearn.svm import SVC  
  8. import functools  
  9. from datetime import datetime  
  10. from time import clock  
  11.   
  12. from sklearn import metrics  
  13. from tools import load_data, load_source, show_source  
  14.   
  15.   
  16. def log_time(fn):  
  17.     @functools.wraps(fn)  
  18.     def wrapper():  
  19.         start = clock()  
  20.         ret = fn()  
  21.         end = clock()  
  22.         print("{}  use time: {:.3f} s" .format (fn.__name__,  end-start))  
  23.         return ret  
  24.     return wrapper  
  25.   
  26. #加载训练数据  
  27. data_x, data_y = load_data("train.txt")  
  28.   
  29. # 加载原始数据  
  30. source_data = load_source("train.csv")  
  31.   
  32. # 打印数据长度  
  33. print("len", len(data_x), len(data_y))  
  34.   
  35. # 设置测试数据数量  
  36. LEN = -1000  
  37. # 划分训练数据和测试数据 注: 当前测试中用到测试数据训练集(train.csv)的数据, 而暂时没有用到测试数据集的数据(test.csv)  
  38. x_train, y_train = data_x[:LEN], data_y[:LEN]  
  39. x_test, y_test = data_x[LEN:], data_y[LEN:]  
  40.  
  41. @log_time  
  42. def tran_LinearRegression():  
  43.     # 定义45个线性分类器,并训练数据,每个分类器只对两个数字进行识别  
  44.     RegressionDict = {}  
  45.     for i in range(10):  
  46.         for j in range(i+110):  
  47.             regr = LinearRegression()  
  48.             RegressionDict["{}-{}".format(i, j)] = regr  
  49.             x_train_tmp = np.array([x_train[index] for index,  y in enumerate(y_train) if y in [i, j]])  
  50.             y_train_tmp = np.array([0 if y == i else 1 for y in y_train if y in [i, j]])  
  51.             regr.fit(x_train_tmp, y_train_tmp)  
  52.   
  53.     # 初始化计数器  
  54.     ret_counter = []  
  55.     for i in range(len(x_test)):  
  56.         ret_counter.append({})  
  57.     # 预测数据,并把结果放到计数器中  
  58.     tmp_dict = {}  
  59.     for key, regression in RegressionDict.items():  
  60.         a, b = key.split('-')  
  61.         y_test_predict = regression.predict(x_test)  
  62.         tmp_dict[key] = [a if item <= 0.5 else b for item in y_test_predict]  
  63.         for i, item in enumerate(tmp_dict[key]):  
  64.             ret_counter[i][item] = ret_counter[i].get(item, 0) + 1  
  65.   
  66.     predict = []  
  67.     for i, item in enumerate(y_test):  
  68.         predict.append(int(sorted(ret_counter[i].items(), key=lambda x:x[1],reverse=True)[0][0]))  
  69.   
  70.     return predict  
  71.   
  72.   
  73. print("\nLinearRegression")  
  74. predicted = tran_LinearRegression()  
  75. print(metrics.classification_report(y_test, predicted))  
  76. print("count:", len(y_test), "ok:", sum([1 for item in range(len(y_test)) if y_test[item]== predicted[item]]))  
  77.   
  78. # 其它常用分类器测试  
  79. map_predictor = {  
  80.     "LogisticRegression":LogisticRegression(),  
  81.     "bayes":GaussianNB(),  
  82.     "KNN": KNeighborsClassifier(),  
  83.     "DecisionTree":DecisionTreeClassifier(),  
  84.     "SVC":SVC()  
  85. }  
  86.   
  87. for key, model in map_predictor.items():  
  88.   
  89.     start = clock()  
  90.     print("start: ",start, key, datetime.utcnow())  
  91.     model.fit(x_train, y_train)  
  92.     end = clock()  
  93.     print("end: ",end,  key, datetime.utcnow())  
  94.     print("{}  use time: {:.3f} s".format(key,  end-start))  
  95.     predicted = model.predict(x_test)  
  96.     print(metrics.classification_report(y_test, predicted))  
  97.     print("count:", len(y_test), "ok:", sum([1 for item in range(len(y_test)) if y_test[item]== predicted[item]]))  

输出结果

[python]  view plain  copy
  1. len 42000 42000  
  2.   
  3. LinearRegression  
  4. /Library/Frameworks/Python.framework/Versions/3.4/lib/python3.4/site-packages/scipy/linalg/basic.py:884: RuntimeWarning: internal gelsd driver lwork query error, required iwork dimension not returned. This is likely the result of LAPACK bug 0038, fixed in LAPACK 3.2.2 (released July 212010). Falling back to 'gelss' driver.  
  5.   warnings.warn(mesg, RuntimeWarning)  
  6. tran_LinearRegression  use time: 2.402 s  
  7.              precision    recall  f1-score   support  
  8.   
  9.         0.0       0.84      0.93      0.89        92  
  10.         1.0       0.91      0.96      0.93       127  
  11.         2.0       0.79      0.84      0.81        97  
  12.         3.0       0.73      0.77      0.75        95  
  13.         4.0       0.87      0.68      0.77       111  
  14.         5.0       0.79      0.73      0.76        85  
  15.         6.0       0.92      0.94      0.93       103  
  16.         7.0       0.83      0.89      0.86       101  
  17.         8.0       0.83      0.83      0.83        88  
  18.         9.0       0.87      0.82      0.85       101  
  19.   
  20. avg / total       0.84      0.84      0.84      1000  
  21.   
  22. count: 1000 ok: 843  
  23. start:  3.398754 SVC 2016-06-30 11:06:01.412944  
  24. end:  26.118339 SVC 2016-06-30 11:06:24.145873  
  25. SVC  use time: 22.720 s  
  26.              precision    recall  f1-score   support  
  27.   
  28.         0.0       0.88      0.93      0.91        92  
  29.         1.0       0.97      0.97      0.97       127  
  30.         2.0       0.76      0.84      0.80        97  
  31.         3.0       0.74      0.75      0.74        95  
  32.         4.0       0.84      0.73      0.78       111  
  33.         5.0       0.80      0.75      0.78        85  
  34.         6.0       0.88      0.92      0.90       103  
  35.         7.0       0.85      0.88      0.86       101  
  36.         8.0       0.85      0.83      0.84        88  
  37.         9.0       0.85      0.82      0.83       101  
  38.   
  39. avg / total       0.85      0.85      0.85      1000  
  40.   
  41. count: 1000 ok: 846  
  42. start:  26.90291 KNN 2016-06-30 11:06:24.930844  
  43. end:  27.276969 KNN 2016-06-30 11:06:25.304922  
  44. KNN  use time: 0.374 s  
  45.              precision    recall  f1-score   support  
  46.   
  47.         0.0       0.87      0.95      0.91        92  
  48.         1.0       0.98      0.99      0.98       127  
  49.         2.0       0.81      0.81      0.81        97  
  50.         3.0       0.71      0.81      0.75        95  
  51.         4.0       0.85      0.75      0.79       111  
  52.         5.0       0.84      0.78      0.80        85  
  53.         6.0       0.91      0.92      0.92       103  
  54.         7.0       0.88      0.87      0.88       101  
  55.         8.0       0.90      0.81      0.85        88  
  56.         9.0       0.81      0.84      0.83       101  
  57.   
  58. avg / total       0.86      0.86      0.86      1000  
  59.   
  60. count: 1000 ok: 857  
  61. start:  27.368384 DecisionTree 2016-06-30 11:06:25.396350  
  62. end:  27.770018 DecisionTree 2016-06-30 11:06:25.798082  
  63. DecisionTree  use time: 0.402 s  
  64.              precision    recall  f1-score   support  
  65.   
  66.         0.0       0.88      0.92      0.90        92  
  67.         1.0       0.98      0.97      0.97       127  
  68.         2.0       0.77      0.75      0.76        97  
  69.         3.0       0.77      0.74      0.75        95  
  70.         4.0       0.75      0.69      0.72       111  
  71.         5.0       0.75      0.71      0.73        85  
  72.         6.0       0.86      0.90      0.88       103  
  73.         7.0       0.71      0.78      0.75       101  
  74.         8.0       0.80      0.78      0.79        88  
  75.         9.0       0.83      0.85      0.84       101  
  76.   
  77. avg / total       0.81      0.81      0.81      1000  
  78.   
  79. count: 1000 ok: 815  
  80. start:  27.772331 LogisticRegression 2016-06-30 11:06:25.800391  
  81. end:  30.191443 LogisticRegression 2016-06-30 11:06:28.219996  
  82. LogisticRegression  use time: 2.419 s  
  83.              precision    recall  f1-score   support  
  84.   
  85.         0.0       0.89      0.95      0.92        92  
  86.         1.0       0.94      0.94      0.94       127  
  87.         2.0       0.76      0.73      0.75        97  
  88.         3.0       0.67      0.73      0.70        95  
  89.         4.0       0.78      0.68      0.73       111  
  90.         5.0       0.80      0.80      0.80        85  
  91.         6.0       0.91      0.89      0.90       103  
  92.         7.0       0.82      0.86      0.84       101  
  93.         8.0       0.80      0.83      0.82        88  
  94.         9.0       0.80      0.79      0.80       101  
  95.   
  96. avg / total       0.82      0.82      0.82      1000  
  97.   
  98. count: 1000 ok: 822  
  99. start:  30.19377 bayes 2016-06-30 11:06:28.222098  
  100. end:  30.209151 bayes 2016-06-30 11:06:28.237481  
  101. bayes  use time: 0.015 s  
  102.              precision    recall  f1-score   support  
  103.   
  104.         0.0       0.88      0.92      0.90        92  
  105.         1.0       0.84      0.97      0.90       127  
  106.         2.0       0.73      0.59      0.65        97  
  107.         3.0       0.61      0.71      0.65        95  
  108.         4.0       0.83      0.64      0.72       111  
  109.         5.0       0.63      0.67      0.65        85  
  110.         6.0       0.84      0.91      0.87       103  
  111.         7.0       0.78      0.74      0.76       101  
  112.         8.0       0.77      0.82      0.80        88  
  113.         9.0       0.80      0.72      0.76       101  
  114.   
  115. avg / total       0.78      0.77      0.77      1000  
  116.   
  117. count: 1000 ok: 774  

  • 3
    点赞
  • 31
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值