kd树识别压缩有的mnist数据集

《一般knn算法识别mnist数据集(代码)》 中用一般的knn方法做了mnist识别。和神经网络方法比起来,knn慢很多,识别1000张图片需要234s。kd树更高效率实现knn的方法,它用二叉树来存储训练集中的样本,搜索k个近邻点时速度更快。具体算法如下:
这里写图片描述
这里写图片描述

对mnist,每张图片是28*28的,且灰度值大于3的就算是手写的痕迹了,因此,在一般knn算法识别mnist数据集(代码)中,通过二值化(one hot)灰度值的方式,效果还不错。但是,由于kd树通过各维的中值来切分区域,如果每一维的数值只有0和1,根本没办法切分。因此,将二值化后的矩阵分别按行求和,按列求和,得到新的56维的特征。相当于做了一个降维,同时,增大各维的取值范围。但是,这样处理后,图片原本良好的平移不变形变差了。识别率估计会降低。无妨,这里重要要的是效率的对比。
代码如下,其中,kd树的建立和搜索仿照了http://www.hankcs.com/ml/k-nearest-neighbor-method.html 中的代码。
代码如下:

#coding:utf-8

import numpy as np
import gzip
import itertools
from datetime import datetime

def _read32(bytestream):
    dt = np.dtype(np.uint32).newbyteorder('>')
    return np.frombuffer(bytestream.read(4), dtype=dt)[0]

def extract_images(input_file, is_value_binary, is_matrix):
    with gzip.open(input_file, 'rb') as zipf:
        magic = _read32(zipf)
        if magic !=2051:
            raise ValueError('Invalid magic number %d in MNIST image file: %s' %(magic, input_file.name))
        num_images = _read32(zipf)
        rows = _read32(zipf)
        cols = _read32(zipf)
        print magic, num_images, rows, cols
        buf = zipf.read(rows * cols * num_images)
        data = np.frombuffer(buf, dtype=np.uint8)
        #reshape成二维
        data = data.reshape(num_images, rows, cols)
        #二值化
        data_value_binary = np.minimum(data, 1)
        #按行相加,存到钱28个元素中,按列相加,存入后28个元素中
        #如果分类效果不好,可再计算按对角线相加、行列式等
        #多加了一列,以便train_x存储标签用。
        data_tidy = np.zeros((num_images, rows + cols + 1), dtype=np.uint32)
        for i in range(num_images):
            data_tidy[i, :rows] = np.sum(data_value_binary[i], axis=1)
            data_tidy[i, rows:(rows+cols)] = (np.sum(data_value_binary[i].transpose(), axis=1))
        return data_tidy


#抽取标签
#仿照tensorflow中mnist.py写的
def extract_labels(input_file):
    with gzip.open(input_file, 'rb') as zipf:
        magic = _read32(zipf)
        if magic != 2049:
            raise ValueError('Invalid magic number %d in MNIST label file: %s' % (magic, input_file.name))
        num_items = _read32(zipf)
        buf = zipf.read(num_items)
        labels = np.frombuffer(buf, dtype=np.uint8)
        return labels


class node:
    def __init__(self, point, label):
        self.left = None
        self.right = None
        self.point = point
        self.label = label  #由于按树存储的时候数据点顺序打乱了,这里将label也存进树里面。
        self.parent = None
        pass

    def set_left(self, left):
        if left == None: pass
        left.parent = self
        self.left = left

    def set_right(self, right):
        if right == None: pass
        right.parent = self
        self.right = right


def median(lst):
    m = len(lst) / 2
    return lst[m], m


def build_kdtree(data, d):
    data = sorted(data, key=lambda x: x[d.next()]) 
    p, m = median(data)
    tree = node(p[:-1], p[-1])

    del data[m]

    #递归查询新节点该存放的位置,同时也递归的切分区域
    if m > 0: tree.set_left(build_kdtree(data[:m], d))
    if len(data) > 1: tree.set_right(build_kdtree(data[m:], d))
    return tree

#计算距离
def distance(a, b):
    diff = a - b
    squaredDiff = diff ** 2
    return np.sum(squaredDiff)


def search_kdtree(tree, d, target, k):
    den = d.next()
    #直到搜索到不存在更近的节点才停止。
    if target[den] < tree.point[den]:
        if tree.left != None:
            return search_kdtree(tree.left, d, target, k)
    else:
        if tree.right != None:
            return search_kdtree(tree.right, d, target, k)

    #持续更新距离最近的k个节点
    def update_best(t, best):
        if t == None: return
        label = t.label
        t = t.point
        d = distance(t, target)
        for i in range(k):
            if d < best[i][1]:
                for j in range(0, i):
                    best[j][1] = best[j+1][1]
                    best[j][0] = best[j+1][0]
                    best[j][2] = best[j+1][2]
                best[i][1] = d
                best[i][0] = t
                best[i][2] = label
    best = []
    for i in range(k):
        best.append( [tree.point, 100000.0, 10] )
    while (tree.parent != None):
        update_best(tree.parent.left, best)
        update_best(tree.parent.right, best)
        tree = tree.parent
    return best


def testHandWritingClass():
    ## step 1: load data
    print "step 1: load data..."
    train_x = extract_images('data/mnist/train_images', True, True)
    train_y = extract_labels('data/mnist/train_labels')
    test_x = extract_images('data/mnist/test_images', True, True)
    test_y = extract_labels('data/mnist/test_labels')

    l = min(train_x.shape[0], train_y.shape[0])
    rows = train_x.shape[1]
    #将训练集的标签存到train_x中,一遍一同存储到kd树中。
    for i in range(l):
        train_x[i, -1] = train_y[i]

    densim = itertools.cycle(range(0, rows-1))
    ## step 2: training...
    print "step 2: build tree..."
    mnist_tree = build_kdtree(train_x, densim)

    ## step 3: testing
    print "step 3: testing..."
    a = datetime.now()
    numTestSamples = test_x.shape[0]
    matchCount = 0
    test_num = numTestSamples
    K = 3
    for i in xrange(test_num):
        best_k = search_kdtree(mnist_tree, densim, test_x[i, :-1], K)
        #计算数量最大的label。
        classCount = {}
        for j in range(K):
            voteLabel = best_k[j][2]
            classCount[voteLabel] = classCount.get(voteLabel, 0) + 1
        maxCount = 0
        predict = 0
        for key, value in classCount.items():
            if value > maxCount:
                maxCount = value
                predict = key
        if predict == test_y[i]:
            matchCount += 1
        if i % 100 == 0:
            print "完成%d张图片"%(i)
    accuracy = float(matchCount) / test_num
    b = datetime.now()
    print "一共运行了%d秒"%((b-a).seconds)

    ## step 4: show the result
    print "step 4: show the result..."
    print 'The classify accuracy is: %.2f%%' % (accuracy * 100)

if __name__ == '__main__':
    testHandWritingClass()

同样,k=3, 识别10000张的结果如下:
这里写图片描述

速度快了585倍,这其中每张图片从784维变成了56维,速度有加快,但kd树本身比knn快也是确定无疑的。
由于数据的压缩,准确率下降了好多。

  • 1
    点赞
  • 17
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值