# 使用 KNN 模型对 MNIST 数据集进行分类

import gzip, os, sys
import numpy as np
import pylab
from scipy.stats import multivariate_normal
from urllib.request import urlretrieve
import matplotlib.pyplot as plt

def download(filename, source='http://yann.lecun.com/exdb/mnist/'):
    """Fetch ``source + filename`` and save it as ``mnist/<filename>``.

    The data is first written to ``mnist/<filename>.part`` and only renamed
    to its final name after the transfer completes. A file at the final path
    is therefore never a truncated, half-downloaded copy.

    Args:
        filename: Name of the remote file; also used as the local file name.
        source: Base URL; it is concatenated directly with ``filename``, so
            it is expected to end with '/'.

    Raises:
        Whatever ``urlretrieve`` raises (e.g. ``urllib.error.URLError``);
        any partial file is removed before the exception propagates.
    """
    print('Downloading %s' % filename)
    # urlretrieve does not create missing parent directories.
    os.makedirs('mnist', exist_ok=True)
    target = 'mnist/' + filename
    partial = target + '.part'
    try:
        urlretrieve(source + filename, partial)
    except BaseException:
        # Clean up on any failure (including Ctrl-C), then re-raise.
        if os.path.exists(partial):
            os.remove(partial)
        raise
    # Atomic rename: the final path only ever holds a complete download.
    os.replace(partial, target)

def load_mnist_images(filename):
    """Load a gzipped MNIST image file from the local ``mnist/`` folder.

    If ``mnist/<filename>`` is not present it is downloaded first. The first
    16 bytes of the decompressed data (the file header) are skipped and the
    remaining bytes are returned as rows of 784 uint8 pixel values.

    Args:
        filename: Name of the ``.gz`` image file inside ``mnist/``.

    Returns:
        A uint8 array of shape ``(n_images, 784)``.
    """
    local_path = 'mnist/' + filename
    if not os.path.exists(local_path):
        download(filename)

    with gzip.open(local_path, 'rb') as stream:
        raw_bytes = stream.read()

    flat_pixels = np.frombuffer(raw_bytes, dtype=np.uint8, offset=16)
    return flat_pixels.reshape(-1, 784)

def load_mnist_labels(filename):
    """Load a gzipped MNIST label file from the local ``mnist/`` folder.

    If ``mnist/<filename>`` is not present it is downloaded first. The first
    8 bytes of the decompressed data (the file header) are skipped; every
    remaining byte is returned as one label.

    Args:
        filename: Name of the ``.gz`` label file inside ``mnist/``.

    Returns:
        A 1-D uint8 array of labels.
    """
    local_path = 'mnist/' + filename
    missing = not os.path.exists(local_path)
    if missing:
        download(filename)

    with gzip.open(local_path, 'rb') as stream:
        raw_bytes = stream.read()

    return np.frombuffer(raw_bytes, dtype=np.uint8, offset=8)

# Load the training and test splits (downloaded on first use).
train_data, train_labels = (
    load_mnist_images('train-images-idx3-ubyte.gz'),
    load_mnist_labels('train-labels-idx1-ubyte.gz'),
)
test_data, test_labels = (
    load_mnist_images('t10k-images-idx3-ubyte.gz'),
    load_mnist_labels('t10k-labels-idx1-ubyte.gz'),
)

# Report the (n_samples, 784) shape of each image matrix.
for _images in (train_data, test_data):
    print(_images.shape)

# Visualization
def show_digit(x, label):
    """Draw one digit image, titled with its label, on the current pylab axes.

    Args:
        x: Flattened image with 784 values; reshaped to 28x28 for display.
        label: Value shown in the subplot title.
    """
    pylab.axis('off')
    pylab.imshow(x.reshape((28, 28)), cmap=pylab.cm.gray)
    # Keep a space between the word and the value: "Label 7", not "Label7".
    pylab.title('Label ' + str(label))

# Preview the first 25 test digits in a 5x5 grid.
grid_rows, grid_cols = 5, 5
pylab.figure(figsize=(10, 8))
for idx in range(grid_rows * grid_cols):
    # subplot positions are 1-based.
    pylab.subplot(grid_rows, grid_cols, idx + 1)
    show_digit(test_data[idx], test_labels[idx])
pylab.tight_layout()
pylab.show()

import time
from sklearn.neighbors import BallTree

# 1-NN has no fitting step beyond building the search structure, so this
# build time is reported as the "training" time.
# perf_counter is a monotonic, high-resolution clock meant for measuring
# intervals (time.time can jump if the system clock is adjusted).
t_before = time.perf_counter()
ball_tree = BallTree(train_data)
t_after = time.perf_counter()

t_training = t_after - t_before
print('Time to build data structure seconds', t_training)

t_before = time.perf_counter()
# query() returns neighbour indices with shape (n_test, k). With k=1, take
# column 0 so the result is always 1-D; np.squeeze would collapse a
# single-sample query to a 0-d scalar.
test_neighbors = ball_tree.query(test_data, k=1,
                                 return_distance=False)[:, 0]
test_predictions = train_labels[test_neighbors]
t_after = time.perf_counter()

t_testing = t_after - t_before
print('time to classify test set(seconds)', t_testing)

# Fraction of test digits whose predicted label equals the true label
# (vectorized instead of Python's builtin sum over a numpy array).
t_accuracy = np.mean(test_predictions == test_labels)
print(t_accuracy)

import pandas as pd
import seaborn as sn
from sklearn import metrics

# Pin the label set to the ten digit classes. The matrix is then always
# 10x10 and lines up with the DataFrame index/columns, even if some digit
# never appears in test_labels or test_predictions.
# Rows are true labels, columns are predicted labels.
digit_classes = range(10)
cm = metrics.confusion_matrix(test_labels, test_predictions,
                              labels=digit_classes)
df_cm = pd.DataFrame(cm, index=digit_classes, columns=digit_classes)
sn.set(font_scale=1.2)

sn.heatmap(df_cm, annot=True, annot_kws={'size': 16}, fmt='g')

# Fraction of misclassified test digits.
t_error = np.mean(test_predictions != test_labels)
print(t_error)

def majority_vote(neighbor_labels, n_classes=10):
    """Return the most frequent label in each row of ``neighbor_labels``.

    Args:
        neighbor_labels: Integer labels of the k nearest neighbours, shape
            ``(n_samples, k)``; a 1-D array is treated as ``k == 1``.
        n_classes: Number of classes; labels must lie in ``[0, n_classes)``.

    Returns:
        1-D integer array of length ``n_samples``. Ties are resolved toward
        the smallest label, because ``argmax`` returns the first maximum.
    """
    votes = np.asarray(neighbor_labels)
    if votes.ndim == 1:
        votes = votes[:, np.newaxis]
    # counts[i, c] = number of neighbours of sample i that carry label c.
    counts = np.stack(
        [(votes == c).sum(axis=1) for c in range(n_classes)], axis=1
    )
    return counts.argmax(axis=1)


# Optional experiment: test error for k = 1..10. Each k re-queries the
# whole test set against the ball tree, which is slow, so it is opt-in.
RUN_K_SWEEP = False

if RUN_K_SWEEP:
    ks = list(range(1, 11))
    k_errors = []
    for k in ks:
        # Neighbour indices, shape (n_test, k).
        k_neighbors = ball_tree.query(test_data, k=k, return_distance=False)
        k_predictions = majority_vote(train_labels[k_neighbors])
        k_error = np.mean(k_predictions != test_labels)
        print(k, k_error)
        k_errors.append(k_error)

    # New figure so the curve is not drawn on top of the heatmap.
    plt.figure()
    plt.plot(ks, k_errors)
    plt.xlabel('k')
    plt.ylabel('error')

plt.show()

在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值