import gzip, os, sys
import numpy as np
import pylab
from scipy.stats import multivariate_normal
from urllib.request import urlretrieve
import matplotlib.pyplot as plt
def download(filename, source='http://yann.lecun.com/exdb/mnist/'):
    """Fetch ``source + filename`` and store it as ``mnist/<filename>``.

    Args:
        filename: Name of the remote file; reused as the local file name.
        source: Base URL that ``filename`` is appended to.

    The ``mnist/`` directory (relative to the current working directory) is
    created on demand, so the first run in a fresh directory does not fail
    with ``FileNotFoundError``.
    """
    print('Downloading %s' % filename)
    # urlretrieve opens the target path directly and does not create
    # missing parent directories.
    os.makedirs('mnist', exist_ok=True)
    urlretrieve(source + filename, 'mnist/' + filename)
def load_mnist_images(filename):
    """Return the images stored in ``mnist/<filename>`` as a 2-D uint8 array.

    The gzip archive is downloaded first if it is not present locally.
    The first 16 bytes of the decompressed stream (the image-file header)
    are skipped, and the remaining bytes are viewed as rows of 784 values,
    i.e. one flattened 28x28 image per row.
    """
    archive_path = 'mnist/' + filename
    if not os.path.exists(archive_path):
        download(filename)
    with gzip.open(archive_path, 'rb') as archive:
        raw_bytes = archive.read()
    pixels = np.frombuffer(raw_bytes, np.uint8, offset=16)
    return pixels.reshape(-1, 784)
def load_mnist_labels(filename):
    """Return the labels stored in ``mnist/<filename>`` as a 1-D uint8 array.

    The gzip archive is downloaded first if it is not present locally.
    The first 8 bytes of the decompressed stream (the label-file header)
    are skipped; every remaining byte is one label.
    """
    archive_path = 'mnist/' + filename
    if not os.path.exists(archive_path):
        download(filename)
    with gzip.open(archive_path, 'rb') as archive:
        raw_bytes = archive.read()
    return np.frombuffer(raw_bytes, np.uint8, offset=8)
# Load the training and test splits (images first, then labels), downloading
# any archive that is not already cached under mnist/.
train_data, train_labels = (
    load_mnist_images('train-images-idx3-ubyte.gz'),
    load_mnist_labels('train-labels-idx1-ubyte.gz'),
)
test_data, test_labels = (
    load_mnist_images('t10k-images-idx3-ubyte.gz'),
    load_mnist_labels('t10k-labels-idx1-ubyte.gz'),
)

# Report the (num_images, 784) shape of each image matrix.
for image_matrix in (train_data, test_data):
    print(image_matrix.shape)
# Visualization
def show_digit(x, label):
    """Draw one digit on the current pylab axes.

    Args:
        x: Array that can be reshaped to 28x28 (one flattened image).
        label: Value appended to the text ``Label`` to form the axes title.
    """
    image = x.reshape((28, 28))
    caption = 'Label%s' % (label,)
    pylab.axis('off')
    pylab.imshow(image, cmap=pylab.cm.gray)
    pylab.title(caption)
# Show the first 25 test digits, with their labels, in a 5x5 grid.
pylab.figure(figsize=(10, 8))
for slot in range(1, 26):
    # subplot positions are 1-based, array indices are 0-based
    pylab.subplot(5, 5, slot)
    show_digit(test_data[slot - 1], test_labels[slot - 1])
pylab.tight_layout()
pylab.show()
import time
from sklearn.neighbors import BallTree

# Build the nearest-neighbour index over the training images.
# perf_counter() is a monotonic, high-resolution clock meant for measuring
# elapsed intervals; time.time() follows the wall clock and can jump when
# the system time is adjusted.
t_before = time.perf_counter()
ball_tree = BallTree(train_data)
t_after = time.perf_counter()
t_training = t_after - t_before
print('Time to build data structure seconds', t_training)

# 1-NN classification: each test image receives the label of its single
# nearest training image.
t_before = time.perf_counter()
test_neighbors = np.squeeze(
    ball_tree.query(test_data, k=1, return_distance=False)
)
test_predictions = train_labels[test_neighbors]
t_after = time.perf_counter()
t_testing = t_after - t_before
print('time to classify test set(seconds)', t_testing)

# Fraction of test samples predicted correctly. The mean of the boolean
# match mask is computed inside NumPy instead of a Python-level sum().
t_accuracy = np.mean(test_predictions == test_labels)
print(t_accuracy)
import pandas as pd
import seaborn as sn
from sklearn import metrics

# Confusion matrix of true test labels against predictions, wrapped in a
# DataFrame whose rows and columns are both indexed 0..9.
cm = metrics.confusion_matrix(test_labels, test_predictions)
digit_axis = range(10)
df_cm = pd.DataFrame(cm, index=digit_axis, columns=digit_axis)

# Draw the matrix as an annotated heatmap.
sn.set(font_scale=1.2)
sn.heatmap(df_cm, annot=True, fmt='g', annot_kws={'size': 16})

# Misclassification rate: share of test samples whose prediction differs
# from the true label.
mismatches = test_predictions != test_labels
t_error = sum(mismatches) / float(len(test_labels))
print(t_error)
# Disabled experiment: sweep k = 1..10 for k-nearest-neighbour classification
# using the ball tree built above, and plot the test error against k.
# NOTE(review): the `test_predictions` line below appears to produce a single
# most-common label for the whole test set rather than a per-sample majority
# vote -- confirm and fix before re-enabling.
#
# from collections import Counter
#
# ks = list(range(11))
# t_trainings = [0]
# t_testings = [0]
# t_errors = [0]
# for k in ks[1:]:
#     test_neighbors = np.squeeze(ball_tree.query(test_data, k=k, return_distance=False))
#     test_predictions = Counter([train_labels[test_neighbor] for test_neighbor in test_neighbors]).most_common(1)[0][0]
#     t_error = sum(test_predictions != test_labels) / float(len(test_labels))
#     print(k, t_error)
#     t_errors.append(t_error)
#
# plt.plot(ks[1:], t_errors[1:])
# plt.xlabel('k')
# plt.ylabel('error')

# Display all currently open figures (e.g. the confusion-matrix heatmap).
plt.show()
