[cs231n]Assignment1_Knn 代码学习

部分资料来源于网络,仅做个人学习之用

目录

1. Download the CIFAR10 datasets, and load it 

2. Define a K Nearest Neighbor Class

3. Train and Test

4. Cross Validation


1. Download the CIFAR10 datasets, and load it 

Setup code

import random
import numpy as np
from cs231n.data_utils import load_CIFAR10
import matplotlib.pyplot as plt

#这是使matplotlib图形内联出现在笔记本中而不是在一个新的窗口 的一个小技巧
%matplotlib inline
plt.rcParams['figure.figsize'] = (10.0, 8.0) # 设置显示图像的最大范围
plt.rcParams['image.interpolation'] = 'nearest' #设置插值的方式:最邻近差值
plt.rcParams['image.cmap'] = 'gray' # 灰度空间 0-255

%load_ext autoreload
%autoreload 2
""" 在执行用户代码前,重新装入软件的扩展和模块。
 autoreload 意思是自动重新装入。它后面可带参数。参数意思你要查你自己的版本帮助文件。
一般说:
无参:装入所有模块。
0:不执行 装入命令。
1: 只装入所有 %aimport 要装模块
2:装入所有 %aimport 不包含的模块。"""

Load the CIFAR10 data

cifar10_dir = 'cs231n/datasets/cifar-10-batches-py'
X_train, y_train, X_test, y_test = load_CIFAR10(cifar10_dir) # 读取数据集

# 作为一个完整性检查,我们打印出训练和测试数据的大小。
print('Training data shape: ', X_train.shape)
print('Training labels shape: ', y_train.shape)
print('Test data shape: ', X_test.shape)
print('Test labels shape: ', y_test.shape)

Show some CIFAR10 images

classes = ['plane', 'car', 'bird', 'cat', 'dear', 'dog', 'frog', 'horse', 'ship', 'truck']
num_classes = len(classes) # 一共有10类
num_each_class = 7 # 每类选7个

"""
enumerate() 函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标,一般用在 for 循环当中。
即enumerate的里面是一个个的pair,第一维是下标,第二维是每一个值。 
y是pair的第一维也就是种类的下标,y_train是训练集里的每一个种类的值 因此就相当于把所有这个种类的抠出来组成一个下标的list

    >>>seasons = ['Spring', 'Summer', 'Fall', 'Winter']
    >>> list(enumerate(seasons))
    [(0, 'Spring'), (1, 'Summer'), (2, 'Fall'), (3, 'Winter')]
"""

for y, cls in enumerate(classes):
    idxs = np.flatnonzero(y_train == y)
"""
np.flatnonzero() 输入一个矩阵,返回了其中非零元素的位置.
np.flatnonzero(y_train == y):在作业中给出的用法:不走寻常路,用来返回某个特定元素的位置找出标签中y类的位置
"""
    idxs = np.random.choice(idxs, num_each_class, replace=False)
"""在所有的这些下标中,随机抽取num_each_class个下标,从中选出我们所需的7个样本,然后这个7个元素不能相同(replace=False)
"""
    for i, idx in enumerate(idxs):
# 对所选的样本的位置和样本所对应的图片在训练集中的位置进行循环
        plt_idx = i * num_classes + (y + 1)  # 计算在子图中所占位置
        plt.subplot(num_each_class, num_classes, plt_idx)  # 说明要画的子图的编号
"""
matplotlib.pyplot.subplot(XXX):
该函数输入量为三个整数比如subplot(2,1,1)前两个数表示子图组成的矩阵的行列数,比如有6个子图,排列成3行2列,那就是subplot(3,2,X)。最后一个数表示要画第X个图了。

参数1代表行数、参数2代表列数、参数3代表第几个图,之所以每次都需要输入第1、2个参数,是因为这两个参数是可变的
"""
        plt.imshow(X_train[idx].astype('uint8'))  # 在上一条指令指定好绘制区域后,画图
        plt.axis('off')  # 不显示坐标尺寸
        if i == 0:
            plt.title(cls) # 写上标题,即类别名
plt.show()
  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值