# Task03：基于机器学习的 CIFAR-10 分类预测
# CIFAR-10 分类
import numpy as np
import platform
import pickle
import os
import matplotlib.pyplot as plt
import time
#加载数据
def load_pickle(f):
    """Unpickle and return one object from the open binary file ``f``.

    On Python 3 the object is loaded with ``encoding='latin1'``, presumably so
    that 8-bit strings in pickles written by Python 2 decode without error.

    NOTE(security): ``pickle`` can execute arbitrary code while loading; only
    call this on trusted files.

    Raises:
        ValueError: if the interpreter's major version is neither '2' nor '3'.
    """
    version = platform.python_version_tuple()
    major = version[0]
    if major not in ('2', '3'):
        raise ValueError("invalid python version: {}".format(version))
    if major == '2':
        return pickle.load(f)
    return pickle.load(f, encoding='latin1')
def loadCIFAR_batch(filename):
    """Load one pickled CIFAR-10 batch file and return its images and labels.

    Args:
        filename: path to a batch file such as ``data_batch_1`` or
            ``test_batch``; it must unpickle to a dict with ``'data'`` and
            ``'labels'`` entries.

    Returns:
        (x, y) where ``x`` is a float64 array of shape (N, 32, 32, 3) in
        (height, width, channel) order and ``y`` is ``np.array(labels)``.
    """
    with open(filename, 'rb') as f:
        datadict = load_pickle(f)
    # Per the CIFAR-10 format each row holds one 3x32x32 image stored
    # channel-first (all red, then green, then blue, each row-major), so view
    # it as (N, C, H, W). Using -1 instead of a fixed 10000 accepts any count.
    x = datadict['data'].reshape(-1, 3, 32, 32)
    # Channels last: (N, C, H, W) -> (N, H, W, C). Note that
    # transpose(0, 3, 2, 1) would yield (N, W, H, C), i.e. every image
    # mirrored across its main diagonal.
    x = x.transpose(0, 2, 3, 1).astype('float')
    y = np.array(datadict['labels'])
    return x, y
def loadCIFAR10(root):
    """Load the CIFAR-10 training and test sets from directory ``root``.

    Reads ``data_batch_1`` through ``data_batch_5`` (in that order) and then
    ``test_batch``, each through ``loadCIFAR_batch``.

    Returns:
        (X, Y, x_test, y_test): the five training batches concatenated along
        the first axis, followed by the test images and labels.
    """
    train_parts = [
        loadCIFAR_batch(os.path.join(root, 'data_batch_%d' % idx))
        for idx in range(1, 6)
    ]
    images, labels = zip(*train_parts)
    X = np.concatenate(images)
    Y = np.concatenate(labels)
    x_test, y_test = loadCIFAR_batch(os.path.join(root, 'test_batch'))
    return X, Y, x_test, y_test
# 将数据分成train和test,以及显示
def data_validation(x_train, y_train, x_test, y_test):
num_training = 49000
num_validation = 1000
num_test = 1000
num_dev = 500
mean_image = np.mean(x_train, axis=0