斯坦福CS231n课程作业——Nearest Neighbor Classifier

课程官网:http://cs231n.stanford.edu/
课程资料:http://cs231n.stanford.edu/syllabus.html
课程PDF:http://cs231n.stanford.edu/slides/2020/lecture_2.pdf
作业资料:https://cs231n.github.io/classification/#k—nearest-neighbor-classifier
CIFAR-10数据官网:http://www.cs.toronto.edu/~kriz/cifar.html

下面是代码,注意所加载的文件路径

import numpy as np
import os
import pickle
import time


class NearestNeighbor(object):
    def __init__(self):
        pass
    # 本网络没有什么训练,只是将所有训练数据都加载到内存中
    def train(self, X, y):
        # X为50000x3072的数组,y为50000x1的数组
        self.Xtr = X
        self.ytr = y

    def predict(self, X):
        # 得到测试数据的个数,本例中为10000
        num_test = X.shape[0]
        # 生成一个10000x10000的全0矩阵,用作存储测试产生的labels,元素的类型与训练数据中的labels相同,本例中应为int32
        Ypred = np.zeros(num_test, dtype=self.ytr.dtype)
        # 循环每个待测试的图片,共10000次
        for i in range(num_test):
            print("testing %d" % i)
            # 使用L1距离,计算所有训练图片到第i张测试图片的距离
            # np.abs为计算绝对值
            # np.sum(..., axis=1)为将结果按照第2个坐标轴的展开方向上求和
            # distances = np.sum(np.abs(self.Xtr - X[i,:]), axis=1)
            # 使用L2距离,计算所有训练图片到第i张测试图片的距离
            # np.square为数组中每个元素计算平方
            # np.sum(..., axis=1)为将结果按照第2个坐标轴的展开方向上求和
            # np.sqrt为求平方根
            distances = np.sqrt(np.sum(np.square(self.Xtr - X[i, :]), axis=1))
            # np.argmin用于找出distances中最小的元素(即与测试图片距离最近的训练图片)的index
            min_index = np.argmin(distances)
            Ypred[i] = self.ytr[min_index]
        return Ypred


def load_CIFAR10(path):
    xs = []
    ys = []
    # 循环,b依次为1,2,3,4,5
    for b in range(1,6):
        # os.path.join用于拼接文件路径
        f = os.path.join(path, 'data_batch_%d' % (b,))
        # 加载每一个data_batch_x文件
        X, Y = load_CIFAR_batch(f)
        # 加载5次,xs为50000个图片的数据,ys为50000个0-9的数字
        xs.append(X)
        ys.append(Y)
    # np.concatenate为拼接数组,???
    Xtr = np.concatenate(xs)
    Ytr = np.concatenate(ys)
    # 删除变量X和Y,???
    del X, Y
    # 加载test_batch文件
    Xte, Yte = load_CIFAR_batch(os.path.join(path, 'test_batch'))
    return Xtr, Ytr, Xte, Yte


def load_CIFAR_batch(filename):
    # with 可以在用完文件之后自动关闭(close)文件,'rb'为按二进制方式只读文件
    with open(filename, 'rb') as f:
        # pickle.load将目标反序列化为对象
        datadict = pickle.load(f, encoding='latin1')
        # batch文件反序列化后得到一个字典,包含'data'和'labels'两个key
        '''
        data是一个10000x3072的numpy数组,数组的每一行储存一个32x32的彩色图像。前1024个数字为红色(red),中间1024为绿色(green),最后1024为蓝色(blue)。
        labels是一个10000个0-9的数字的列表。
        '''
        X = datadict['data']
        Y = datadict['labels']
        # reshape用来改变数据的格式,transpose用来交换张量的不同轴,astype用来改变np.array中所有数据元素的数据类型。
        X = X.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")
        # 把Y的格式变为np.array
        Y = np.array(Y)
        return X, Y


def runNN():
    print("load data")
    # Xtr为训练数据中的图片数据,Ytr为训练数据中的标签数据
    # Xte为测试数据中的图片数据,Yte为测试数据中的标签数据
    Xtr, Ytr, Xte, Yte = load_CIFAR10('data/cifar-10-batches-py')
    # 把Xtr和Xte两个变量的格式变为50000x3072的数组
    Xtr_rows = Xtr.reshape(Xtr.shape[0], 32 * 32 * 3)
    Xte_rows = Xte.reshape(Xte.shape[0], 32 * 32 * 3)
    # 构造一个NearestNeighbor()类的对象nn
    nn = NearestNeighbor()
    print("start training")
    # 用训练数据对nn这个网络进行训练
    nn.train(Xtr_rows, Ytr)
    print("start testing")
    # 用训练好的网络nn进行测试
    Yte_predict = nn.predict(Xte_rows)
    # 统计测试的正确率
    print('accuracy: %f' % (np.mean(Yte_predict == Yte)))


if __name__ == '__main__':
    print ("开始时间: ", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
    runNN()
    print ("结束时间: ", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))

使用L1距离的打印结果

accuracy: 0.385900

使用L2距离的打印结果

accuracy: 0.353900
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值