使用scikit-learn进行Python中的MNIST手写数字识别
去发现同类优质开源项目:https://gitcode.com/
该项目是一个引人入胜的机器学习教程,旨在解决广泛认知的MNIST手写数字分类问题。这里我们采用的是支持向量机(SVM),一种强大的监督学习算法,以原始像素特征作为输入。整个解决方案完全使用Python编写,并依赖于易于使用的机器学习库——scikit-learn。
这个项目的初衷并非追求最先进的性能,而是教你如何利用scikit-learn训练SVM图像分类器。尽管解决方案没有针对高精度优化,但测试结果仍然相当出色(见下文表格)。
如果你想要达到顶峰表现,可以参考以下两个资源:
以下是部分方法与模型的准确性对比表:
| 方法 | 准确率 | 备注 | |-------------------------------------------|----------|------------------| | 随机森林 | 0.937 | | | 简单的一层神经网络 | 0.926 | | | 简单的两层卷积神经网络 | 0.981 | | | SVM RBF | 0.9852 | C=5, gamma=0.05 | | 线性SVM + 尼斯特罗姆核近似 | | | | 线性SVM + 四ier核近似 | | |
项目设置
本教程在Ubuntu 18.10上编写并测试。项目包含了所有必要的Python包和虚拟环境管理工具Pipenv:
- 安装Python 3.6或以上版本。
- 安装Pipenv。
- 使用Git克隆仓库。
- 运行
pipenv install
安装所有必需的Python包。
git clone https://github.com/ksopyla/svm_mnist_digit_classification.git
cd svm_mnist_digit_classification
pipenv install
解决方案
在这个教程中,我们将采用两种不同的SVM学习方法。首先,我们使用经典的RBF核SVM,虽然训练大型数据集时可能比较耗时,但准确度较高。其次,我们会运用线性SVM,它允许O(n)时间复杂度的训练,以实现更快的速度。为了提高准确性,我们还将应用一些技巧,如核近似。
项目包括三个文件:
- mnist_helpers.py - 包含一些可视化函数,如MNIST数字的可视化和混淆矩阵。
- svm_mnist_classification.py - SVM RBF核分类脚本。
- svm_mnist_embeddings.py - 线性SVM与嵌入式方法脚本。
SVM with RBF核
svm_mnist_classification.py 脚本下载MNIST数据库并显示随机数字,然后标准化数据(均值为0,标准差为1),最后通过网格搜索交叉验证找到最佳参数。
线性SVM与不同嵌入方式
线性SVM的优势在于存在许多训练时间为O(n)的算法,相比于其他非线性SVM(大多数为O(n^2)),它们在大数据上训练特别快。
进一步改进
- 利用人工样本扩展训练集
- 应用随机化参数搜索
参考资料
项目文档提供了丰富的SVM和MNIST学习材料链接,供进一步深入研究。
通过参与这个开源项目,你可以熟悉SVM的工作原理,理解如何处理图像数据,并掌握如何优化模型参数。不论你是初学者还是经验丰富的开发者,都能从这个项目中获益匪浅。现在就加入进来,开启你的MNIST手写数字识别之旅吧!
去发现同类优质开源项目:https://gitcode.com/