CS231n课程笔记——Nearest Neighbor分类器 示例代码

图像分类笔记(上)链接 https://zhuanlan.zhihu.com/p/20894041?refer=intelligentunit


import numpy as np
from cs231n.data_utils import load_CIFAR10

class NearestNeighbor(object):
  def __init__(self):
    pass

  def train(self, X, y):
    """ X is N x D where each row is an example. Y is 1-dimension of size N """
    # the nearest neighbor classifier simply remembers all the training data
    self.Xtr = X
    self.ytr = y

  def predict(self, X):
    """ X is N x D where each row is an example we wish to predict label for """
    num_test = X.shape[0]
    # lets make sure that the output type matches the input type
    Ypred = np.zeros(num_test, dtype = self.ytr.dtype)

    # loop over all test rows
    for i in xrange(num_test):
      # find the nearest training image to the i'th test image
      # using the L1 distance (sum of absolute value differences)
      distances = np.sum(np.abs(self.Xtr - X[i,:]), axis = 1)
      min_index = np.argmin(distances) # get the index with smallest distance
      Ypred[i] = self.ytr[min_index] # predict the label of the nearest example

    return Ypred

Xtr, Ytr, Xte, Yte = load_CIFAR10('cs231n/data/cifar10/') # a magic function we provide
# flatten out all images to be one-dimensional
Xtr_rows = Xtr.reshape(Xtr.shape[0], 32 * 32 * 3) # Xtr_rows becomes 50000 x 3072
Xte_rows = Xte.reshape(Xte.shape[0], 32 * 32 * 3) # Xte_rows becomes 10000 x 3072

nn = NearestNeighbor() # create a Nearest Neighbor classifier class
nn.train(Xtr_rows, Ytr) # train the classifier on the training images and labels
Yte_predict = nn.predict(Xte_rows) # predict labels on the test images
# and now print the classification accuracy, which is the average number
# of examples that are correctly predicted (i.e. label matches)
print 'accuracy: %f' % ( np.mean(Yte_predict == Yte) )

编译环境:Python 2.7 依赖项:numpy scipy

numpy如果是pip安装的话会提示找不到mkl,解决办法:numpy和scipy都用whl文件安装。

64位系统报错memoryerror的解决办法:运行程序相关的内容不要放在C盘,关掉正在运行的其他软件。


Yte_predict = nn.predict(Xte_rows) # predict labels on the test images
这行语句我运行了两个多小时还没有出结果,因此我把数据量改小后验证程序是能够运行的。

按照官方笔记上来说,如果你用这段代码跑CIFAR-10,你会发现准确率能达到38.6%。

如果你能运行出来的话这个准确率是个参考。


cifar10下载:http://www.cs.toronto.edu/~kriz/cifar.html 选择python版。

代码相关的其他文件可以在这里下载

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值