对于k最邻近算法(KNN):
- 训练时,分类器记住所有的训练数据;
- 测试时,每一个测试图像都要和所有的训练图像计算距离,然后选取距离最近的k个图像,最后选取k个图像中出现次数最多的类标签作为输出(预测标签)。
准备工作
为jupyter notebook运行一些准备代码
# 使python2.x也能使用print()
from __future__ import print_function
# 之后需要随机选7张图片
import random
import numpy as np
# 导入数据集
from cs231n.data_utils import load_CIFAR10
# 为画图做准备
import matplotlib.pyplot as plt
# 这是使matplotlib图像出现jupyter notebook里,而不是出现在新窗口的一个小技巧
%matplotlib inline
# 设置画图的默认大小
plt.rcParams['figure.figsize'] = (10.0, 8.0)
# 最近邻差插值: 像素为正方形
plt.rcParams['image.interpolation'] = 'nearest'
# 使用灰度输出而不是彩色输出
plt.rcParams['image.cmap'] = 'gray'
# 在执行用户代码前,重新装入软件的扩展和模块。autoreload意思是自动重新装入。无参:装入所有模块。
# 修改完.py文件后不需要重新从头运行,只需要重新装载修改过的函数即可
%load_ext autoreload
%autoreload 2
加载CIFAR-10数据
cifar10_dir = 'cs231n/datasets/cifar-10-batches-py'
# 清除变量,以防加载多次(可能导致内存问题)
try:
del X_train, y_train
del X_test, y_test
print('Clear previously loaded data.')
except:
pass
# 加载数据
X_train, y_train, X_test, y_test = load_CIFAR10(cifar10_dir)
# 输出训练数据和测试数据的大小
print('Training data shape: ', X_train.shape)
# Training data shape: (50000, 32, 32, 3)
print('Training labels shape: ', y_train.shape)
# Training labels shape: (50000,)
print('Test data shape: ', X_test.shape)
# Test data shape: (10000, 32, 32, 3)
print('Test labels shape: ', y_test.shape)
# Test labels shape: (10000,)
每个类可视化一些图像
# 类标签
classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
# 类的个数(10)
num_classes = len(classes)
# 每个类的样例
samples_per_class = 7
# y是索引值(0-9),cls是名字('plane'等)
for y, cls in enumerate(classes):
# 在训练数据标签y_train中记录标签为y的下标(y_train中存放的是数字)
idxs = np.flatnonzero(y_train == y)
# 同一标签的图片中随机选7张(记录的是下标)
idxs = np.random.choice(idxs, samples_per_class, replace=False)
# i代表该类选出来的第i张图片,idx是该图片在训练数据集里的下标
for i, idx in enumerate(idxs):
# 计算图片所在的位置
plt_idx = i * num_classes + y + 1
# 总共是7行10列的图片
plt.subplot(samples_per_class, num_classes, plt_idx)
plt.imshow(X_train[idx].astype('uint8'))
# 不显示坐标轴
plt.axis('off')
if i == 0:
# 第一行图片上输出类别标签
plt.title(cls)
plt.show()
可视化结果:
为了更高效地执行代码,我们只取样部分数据。选取5000张测试图片,500张测试图片。
num_training = 5000
mask = list