Classification

MNIST

这是一个包含了70000个手写阿拉伯数字的图片样例,

GetTheData

ML教材上的代码样例会报错,可能是某个模块版本升级后的问题,查询github得到以下替代代码。

from shutil import copyfileobj
from six.moves import urllib
from sklearn.datasets.base import get_data_home
import os

def fetch_mnist(data_home=None):
    mnist_alternative_url = "https://github.com/amplab/datascience-sp14/raw/master/lab7/mldata/mnist-original.mat"
    data_home = get_data_home(data_home=data_home)
    data_home = os.path.join(data_home, 'mldata')
    if not os.path.exists(data_home):
        os.makedirs(data_home)
    mnist_save_path = os.path.join(data_home, "mnist-original.mat")
    if not os.path.exists(mnist_save_path):  # save at ~/scikit-learn/mldata/..
        mnist_url = urllib.request.urlopen(mnist_alternative_url)
        with open(mnist_save_path, "wb") as matlab_file:
            copyfileobj(mnist_url, matlab_file)

文件储存在~/scikit-learn/mldata/..目录下,调用fetch_mnist()获取数据。该数据以字典序降序排列(实际上是多组0-9组成),每行其实是28*28个像素点的灰度值,一共784列。以下代码将数据读入内存。

from sklearn.datasets import fetch_mldata
mnist = fetch_mldata('MNIST original')
X, y = mnist["data"], mnist["target"]
mnist

查看其中一个样例。

%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt

some_digit = X[36000]
some_digit_image = some_digit.reshape(28, 28)

plt.imshow(some_digit_image, cmap = matplotlib.cm.binary,
           interpolation="nearest")
plt.axis("on")
plt.show()

SplitTheSet

划分测试集和训练集。以前6w个为训练集,后1w个为测试集。

X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]

不希望在交叉验证时分fold缺少某个数字,有些算法对样例顺序比较敏感(如果训练连续类似的样例,则会影响准确性),故打乱训练集顺序。

import numpy as np

shuffle_index = np.random.permutation(60000)
X_train, y_train = X_train[shuffle_index], y_train[shuffle_index]

Binary Classifier

训练二分类算法,算法只回答true和false,以数字是不是5为例。

y_train_5 = (y_train == 5)  # True for all 5s, False for all other digits.
y_test_5 = (y_test == 5)

SGD

Stochastic Gradient Descent(SGD),就是独立地为所有样例打分,排成线性,去某一阈值,高之为真,低之为假。该算法也适用于在线学习。

from sklearn.linear_model import SGDClassifier

sgd_clf = SGDClassifier(random_state=42)
sgd_clf.fit(X_train, y_train_5)

sgd_clf.predict([some_digit])

Performance Measures

要制定标准来评定算

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值