softmax方法分类CIFAR-10数据
Code:
# 导包
import numpy as np
import pickle
import matplotlib.pyplot as plt
import math
# Load the training data from the five CIFAR-10 batch files
# (data_batch_1 .. data_batch_5) into module-level arrays.
# NOTE(review): pickle.load can execute arbitrary code embedded in a crafted
# file -- only point this at batch files obtained from a trusted source.
train_data = []
train_label = []
for i in range(1, 6):
    # `with` guarantees the handle is closed even if unpickling raises;
    # previously every batch file was opened and never closed.
    with open('work/CIFAR-10/data_batch_' + str(i), 'rb') as file_object:
        # encoding='bytes' keeps Python-2 str objects as bytes, which is why
        # the dict keys are accessed as b'data' / b'labels' below.
        data_object = pickle.load(file_object, encoding='bytes')  # a dict
    # Rows of b'data' pair with entries of b'labels' (presumably one sample
    # per row -- confirm against the dataset docs); extend adds them in order.
    train_data.extend(data_object[b'data'])
    train_label.extend(data_object[b'labels'])
# train_data / train_label are plain Python lists at this point; convert them
# to NumPy arrays (features cast to float64) for vectorized computation.
train_data = np.array(train_data).astype("float")
train_label = np.array(train_label)
print("train_data shape:" + str(train_data.shape))
print("train_label shape:" + str(train_label.shape))
#%%
# Load the CIFAR-10 test batch into module-level arrays.
# NOTE(review): pickle.load can execute arbitrary code embedded in a crafted
# file -- only load test_batch from a trusted source.
test_data = []
test_label = []
# `with` closes the handle once the dict is loaded; previously test_file was
# left open for the rest of the script.
with open('work/CIFAR-10/test_batch', 'rb') as test_file:
    # encoding='bytes' keeps Python-2 str objects as bytes, hence the
    # b'data' / b'labels' keys used below.
    test_file_object = pickle.load(test_file, encoding='bytes')
# Rows of b'data' pair with entries of b'labels'; extend keeps their order.
test_data.extend(test_file_object[b'data'])
test_label.extend(test_file_object[b'labels'])
# Convert the Python lists to NumPy arrays (features cast to float64).
test_data = np.array(test_data).astype("float")
test_label = np.array(test_label)
print("test_data s