CS231n-assignment1-KNN

In[1]: 

import random
import numpy as np
from cs231n.data_utils import load_CIFAR10
import matplotlib.pyplot as plt
from __future__ import print_function
#%matplotlib inline

# figsize设置图形大小,宽10.0,高8.0, interpolation是图像内插 cmap是分配颜色
plt.rcParams['figure.figsize'] = (10.0, 8.0)
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'
%load_ext autoreload
%autoreload 2
# %load_ext autoreload在执行用户代码前,重新装入 软件的扩展和模块
# autoreload 意思是自动重新装入,0:不执行装入命令 1:只装入%aimport要装入的模块 2:装入所有aimport不包含的模块

载入数据:

In[2]: 

from cs231n.features import color_histogram_hsv, hog_feature

def get_CIFAR10_data(num_training=49000, num_validation=1000, num_test=1000):
    cifar10_dir = 'cs231n/datasets/cifar-10-batches-py'
    X_train, y_train, X_test, y_test = load_CIFAR10(cifar10_dir)
    # 再次筛选
    mask = list(range(num_training, num_training + num_validation))
    X_val = X_train[mask]
    y_val = y_train[mask]
    mask = list(range(num_training))
    X_train = X_train[mask]
    y_train = y_train[mask]
    mask = list(range(num_test))
    X_test = X_test[mask]
    y_test = y_test[mask]
    
    return X_train, y_train, X_val, y_val, X_test, y_test

#此时变量已经返回,清理变量防止多次加载占用内存
try:
   del X_train, y_train
   del X_test, y_test
   print('Clear previously loaded data.')
except:
   pass

X_train, y_train, X_val, y_val, X_test, y_test = get_CIFAR10_data()

print('Training data shape: ', X_train.shape)
print('Training labels shape: ', y_train.shape)
print('Test data shape: ', X_test.shape)
print('Test labels shape: ', y_test.shape)

In[3]:

# 类(labels)
classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
num_classes = len(classes)
#每个类别采样数
samples_per_class = 7
# cls代指类
#enumerate 枚举 (0,plane 1,car 2,bird……) 
for y, cls in enumerate(classes):
    #找出矩阵中非零元素y_train=y的位置
idxs = np.flatnonzero(y_train == y)
# 在idxs中选出samples_per_class个样本,replace:false表示不能取相同数字
    idxs = np.random.choice(idxs, samples_per_class, replace=False)
# 随机从对所选的样本的位置和样本所对应的图片在训练集中的位置进行循环
for i, idx in enumerate(idxs):
        plt_idx = i * num_classes + y + 1 #分别对应的类
        plt.subplot(samples_per_class, num_classes, plt_idx) #参数1代表行数、参数2代表列数、参数3代表第几个图,之所以每次都需要输入第1、2个参数,这两个参数是可变的
        plt.imshow(X_train[idx].astype('uint8'))#画图, plt.imshow(a)中a的格式要求是width*height*depth,数据类型是无符号整型(uint8),由上一个函数指定宽高深
        plt.axis('off') #关闭坐标轴显示
        if i == 0:
            plt.title(cls) #写上类别名
plt.show() #显示

In[4]:

子样例测试

num_training = 5000
mask = list(range(num_training))
X_train = X_train[mask]
y_train = y_train[mask]

num_test = 500
mask = list(range(num_test))
X_test = X_test[mask]
y_test = y_test[mask]

In[5]:

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Esaka7

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值