【深度之眼cs231n第七期】笔记(四)


对于k最邻近算法(KNN):

  1. 训练时,分类器记住所有的训练数据;
  2. 测试时,每一个测试图像都要和所有的训练图像计算距离,然后选取距离最近的k个图像,最后选取k个图像中出现次数最多的类标签作为输出(预测标签)。

准备工作

为jupyter notebook运行一些准备代码

# 使python2.x也能使用print()
from __future__ import print_function
# 之后需要随机选7张图片
import random
import numpy as np
# 导入数据集
from cs231n.data_utils import load_CIFAR10
# 为画图做准备
import matplotlib.pyplot as plt

# 这是使matplotlib图像出现jupyter notebook里,而不是出现在新窗口的一个小技巧
%matplotlib inline
# 设置画图的默认大小
plt.rcParams['figure.figsize'] = (10.0, 8.0) 
# 最近邻差插值: 像素为正方形
plt.rcParams['image.interpolation'] = 'nearest'
# 使用灰度输出而不是彩色输出
plt.rcParams['image.cmap'] = 'gray'

# 在执行用户代码前,重新装入软件的扩展和模块。autoreload意思是自动重新装入。无参:装入所有模块。
# 修改完.py文件后不需要重新从头运行,只需要重新装载修改过的函数即可
%load_ext autoreload
%autoreload 2

加载CIFAR-10数据

cifar10_dir = 'cs231n/datasets/cifar-10-batches-py'
# 清除变量,以防加载多次(可能导致内存问题)
try:
   del X_train, y_train
   del X_test, y_test
   print('Clear previously loaded data.')
except:
   pass

# 加载数据
X_train, y_train, X_test, y_test = load_CIFAR10(cifar10_dir)
# 输出训练数据和测试数据的大小
print('Training data shape: ', X_train.shape)
# Training data shape:  (50000, 32, 32, 3)
print('Training labels shape: ', y_train.shape)
# Training labels shape:  (50000,)
print('Test data shape: ', X_test.shape)
# Test data shape:  (10000, 32, 32, 3)
print('Test labels shape: ', y_test.shape)
# Test labels shape:  (10000,)

每个类可视化一些图像

# 类标签
classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
# 类的个数(10)
num_classes = len(classes)
# 每个类的样例
samples_per_class = 7
# y是索引值(0-9),cls是名字('plane'等)
for y, cls in enumerate(classes):
    # 在训练数据标签y_train中记录标签为y的下标(y_train中存放的是数字)
    idxs = np.flatnonzero(y_train == y)
    # 同一标签的图片中随机选7张(记录的是下标)
    idxs = np.random.choice(idxs, samples_per_class, replace=False)
    # i代表该类选出来的第i张图片,idx是该图片在训练数据集里的下标
    for i, idx in enumerate(idxs):
        # 计算图片所在的位置
        plt_idx = i * num_classes + y + 1
        # 总共是7行10列的图片
        plt.subplot(samples_per_class, num_classes, plt_idx)
        plt.imshow(X_train[idx].astype('uint8'))
        # 不显示坐标轴
        plt.axis('off')
        if i == 0:
            # 第一行图片上输出类别标签
            plt.title(cls)
plt.show()

可视化结果:
在这里插入图片描述
为了更高效地执行代码,我们只取样部分数据。选取5000张测试图片,500张测试图片。

num_training = 5000
mask = list
  • 2
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值