In[1]:
import random
import numpy as np
from cs231n.data_utils import load_CIFAR10
import matplotlib.pyplot as plt
from __future__ import print_function
#%matplotlib inline
# figsize设置图形大小,宽10.0,高8.0, interpolation是图像内插 cmap是分配颜色
plt.rcParams['figure.figsize'] = (10.0, 8.0)
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'
%load_ext autoreload
%autoreload 2
# %load_ext autoreload在执行用户代码前,重新装入 软件的扩展和模块
# autoreload 意思是自动重新装入,0:不执行装入命令 1:只装入%aimport要装入的模块 2:装入所有aimport不包含的模块
载入数据:
In[2]:
from cs231n.features import color_histogram_hsv, hog_feature
def get_CIFAR10_data(num_training=49000, num_validation=1000, num_test=1000,
                     cifar10_dir='cs231n/datasets/cifar-10-batches-py', load_fn=None):
    """Load CIFAR-10 and split it into train / validation / test sets.

    The validation set is carved out of the loaded training data: the first
    ``num_training`` examples form the training set and the next
    ``num_validation`` examples form the validation set. The first
    ``num_test`` examples of the loaded test data form the test set.

    Args:
        num_training: number of training examples to keep.
        num_validation: number of examples, taken immediately after the
            training examples, to use for validation.
        num_test: number of test examples to keep.
        cifar10_dir: dataset directory handed to the loader.
        load_fn: callable ``load_fn(cifar10_dir)`` returning
            ``(X_train, y_train, X_test, y_test)``; defaults to ``load_CIFAR10``.

    Returns:
        Tuple ``(X_train, y_train, X_val, y_val, X_test, y_test)``.

    Raises:
        ValueError: if a split size is negative or the requested splits need
            more examples than the loader returned.
    """
    if load_fn is None:
        load_fn = load_CIFAR10
    X_train, y_train, X_test, y_test = load_fn(cifar10_dir)

    if min(num_training, num_validation, num_test) < 0:
        raise ValueError('Split sizes must be non-negative, got num_training=%d, '
                         'num_validation=%d, num_test=%d'
                         % (num_training, num_validation, num_test))
    if num_training + num_validation > len(X_train):
        raise ValueError('Requested %d training + %d validation examples, but only %d '
                         'training examples are available'
                         % (num_training, num_validation, len(X_train)))
    if num_test > len(X_test):
        raise ValueError('Requested %d test examples, but only %d are available'
                         % (num_test, len(X_test)))

    # .copy(): for NumPy arrays a plain slice is a view, which would keep the
    # full loaded arrays alive for as long as any split is referenced.
    val_end = num_training + num_validation
    X_val = X_train[num_training:val_end].copy()
    y_val = y_train[num_training:val_end].copy()
    X_train = X_train[:num_training].copy()
    y_train = y_train[:num_training].copy()
    X_test = X_test[:num_test].copy()
    y_test = y_test[:num_test].copy()
    return X_train, y_train, X_val, y_val, X_test, y_test
# When this cell is re-run, free the arrays from the previous run before
# loading again, so the dataset is not held in memory twice.
_stale_names = [name for name in ('X_train', 'y_train', 'X_val', 'y_val', 'X_test', 'y_test')
                if name in globals()]
for _name in _stale_names:
    del globals()[_name]
if _stale_names:
    print('Clear previously loaded data.')

X_train, y_train, X_val, y_val, X_test, y_test = get_CIFAR10_data()
print('Training data shape: ', X_train.shape)
print('Training labels shape: ', y_train.shape)
print('Validation data shape: ', X_val.shape)
print('Validation labels shape: ', y_val.shape)
print('Test data shape: ', X_test.shape)
print('Test labels shape: ', y_test.shape)
In[3]:
# Show a grid of random training images: one column per class and
# `samples_per_class` rows, with the class name above each column.
classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
num_classes = len(classes)
# Number of images drawn per class.
samples_per_class = 7
# Dedicated, seeded RNG: the same images are shown on every re-run, and the
# global np.random state is left untouched.
sample_rng = np.random.RandomState(0)

# squeeze=False keeps `sample_axes` 2-D even if either grid dimension is 1.
sample_fig, sample_axes = plt.subplots(samples_per_class, num_classes, squeeze=False)
for y, cls in enumerate(classes):
    # Indices of all training examples whose label equals y.
    idxs = np.flatnonzero(y_train == y)
    # Pick samples_per_class distinct examples (replace=False forbids repeats).
    idxs = sample_rng.choice(idxs, samples_per_class, replace=False)
    for i, idx in enumerate(idxs):
        ax = sample_axes[i, y]  # row = i-th sample, column = class y
        # Cast to uint8 so imshow interprets RGB values on a 0-255 scale;
        # assumes X_train[idx] is an (H, W, 3) image — TODO confirm against load_CIFAR10.
        ax.imshow(X_train[idx].astype('uint8'))
        ax.axis('off')
        if i == 0:
            ax.set_title(cls)
plt.show()
In[4]:
子样本测试:只保留训练集和测试集的前一部分样本,以加快后续实验。
# Shrink the training and test sets so the experiments below run quickly.
num_training = 5000
num_test = 500

# Keep the first num_training training examples (indexing with a list of
# positions yields new arrays rather than views).
mask = list(range(num_training))
X_train, y_train = X_train[mask], y_train[mask]

# Keep the first num_test test examples; `mask` is reused for these indices.
mask = list(range(num_test))
X_test, y_test = X_test[mask], y_test[mask]
In[5]: