图像分类笔记(上)链接 https://zhuanlan.zhihu.com/p/20894041?refer=intelligentunit
import numpy as np
from cs231n.data_utils import load_CIFAR10
class NearestNeighbor(object):
def __init__(self):
pass
def train(self, X, y):
""" X is N x D where each row is an example. Y is 1-dimension of size N """
# the nearest neighbor classifier simply remembers all the training data
self.Xtr = X
self.ytr = y
def predict(self, X):
""" X is N x D where each row is an example we wish to predict label for """
num_test = X.shape[0]
# lets make sure that the output type matches the input type
Ypred = np.zeros(num_test, dtype = self.ytr.dtype)
# loop over all test rows
for i in xrange(num_test):
# find the nearest training image to the i'th test image
# using the L1 distance (sum of absolute value differences)
distances = np.sum(np.abs(self.Xtr - X[i,:]), axis = 1)
min_index = np.argmin(distances) # get the index with smallest distance
Ypred[i] = self.ytr[min_index] # predict the label of the nearest example
return Ypred
Xtr, Ytr, Xte, Yte = load_CIFAR10('cs231n/data/cifar10/') # a magic function we provide
# flatten out all images to be one-dimensional
Xtr_rows = Xtr.reshape(Xtr.shape[0], 32 * 32 * 3) # Xtr_rows becomes 50000 x 3072
Xte_rows = Xte.reshape(Xte.shape[0], 32 * 32 * 3) # Xte_rows becomes 10000 x 3072
nn = NearestNeighbor() # create a Nearest Neighbor classifier class
nn.train(Xtr_rows, Ytr) # train the classifier on the training images and labels
Yte_predict = nn.predict(Xte_rows) # predict labels on the test images
# and now print the classification accuracy, which is the average number
# of examples that are correctly predicted (i.e. label matches)
print 'accuracy: %f' % ( np.mean(Yte_predict == Yte) )
编译环境:Python 2.7 依赖项:numpy scipy
numpy如果是pip安装的话会提示找不到mkl,解决办法:numpy和scipy都用whl文件安装。
64位系统报错memoryerror的解决办法:运行程序相关的内容不要放在C盘,关掉正在运行的其他软件。
Yte_predict = nn.predict(Xte_rows) # predict labels on the test images
这行语句我运行了两个多小时还没有出结果,因此我把数据量改小后验证程序是能够运行的。
按照官方笔记上来说,如果你用这段代码跑CIFAR-10,你会发现准确率能达到38.6%。
如果你能运行出来的话这个准确率是个参考。
cifar10下载:http://www.cs.toronto.edu/~kriz/cifar.html 选择python版。
代码相关的其他文件可以在这里下载