CS321n入门之KNN(1)——一个初学者的随学笔记
来源:吃树叶的土豆
视频中的第一个实例就是一个简单的KNN算法,作为小白特别是对py特性还不是很清楚的初学者。刷完视频之后再去看作业PPt的时候着实是一脸懵逼。通过几个小时的调试勉强弄明白是什么情况,同时也希望给同样是小白的初学者分享一些经验,以提高学习效率。
1、了解数据集
视频中实现KNN进行图像分类算法的数据集是:CIFAR-10。所以我们从这个数据集开始分析:
通过调试我们可以很直观的看到CIFAR-10数据集中的数据呈现状况如下:
1)batch_label: 用于标记数据集的类型,如:这里的训练集:training batch
2)label:标签,CIFAR-10是具有10种类型的图像数据库。这里的标签是指是数据集中的每张图片属于10种类型中的哪一类:“0~9”
3)data:数据,这里存的是数据集中所有图像的各个像素值。实际上是将nm二维的图像数据先转化成一位数组,由此data中的每一行则代表一张图像。图像大小为3232
4)filename:文件名,记录数据集中所有图像的文件名
如下图所示:
2、相关步骤
该算法实现的主要步骤可以分为以下:
1)数据处理:将数据集中的数据进行格式标准处理,这里通常会使用到numpy数据处理库。
2)模型建立:利用已封装好或自定义的模型进行训练,也就是class
3)评价指标:通过测试集的分类结果和测试集中已标记的label进行比较,求得该模型训练结果的有效性。由此来判断一个模型的好坏。
4)模型优化:调整超参数。在本算法中存在的超参数包括两个。其一:distance距离公式的选择,通常是曼哈顿距离和欧式距离;其二,k值的设定。
3、代码
import pickle as p
import matplotlib.pyplot as plt
import numpy as np
# NearestNeighbor class
class NearestNeighbor(object):
def __init__(self):
pass
def train(self, X, y):
""" X is N x D where each row is an example. Y is 1-dimension of size N """
# the nearest neighbor classifier simply remembers all the training data
self.Xtr = X
self.ytr = y
def predict(self, X):
""" X is N x D where each row is an example we wish to predict label for """
num_test = X.shape[0]#获取数据大小
print(num_test)
# lets make sure that the output type matches the input type
Ypred = np.zeros(num_test, dtype=self.ytr.dtype)
# loop over all test rows
for i in range(num_test):
# find the nearest training image to the i'th test image
# using the L1 distance (sum of absolute value differences)
distances = np.sum