#导包
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
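# 加载训练数据 — Load the training data.
# - Python 的 with 语句会自动帮我们调用 close() 方法 — the `with` statement closes the file automatically.
# - 使用 pickle.load(fo, encoding='ISO-8859-1') 读取数据 — read each batch with pickle.load(fo, encoding='ISO-8859-1').
# - 用 transpose([1, 2, 0]) 把 (3, 32, 32) 的图像转成 (32, 32, 3) 以便显示 — transpose([1, 2, 0]) turns a (3, 32, 32) image into (32, 32, 3) for display.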
# Helper that opens one pickled data batch.
def unpickle(file):
    """Load and return the object pickled in *file*.

    For the batch files used in this notebook the result is a dict with the
    keys 'batch_label', 'labels', 'data' and 'filenames'.

    ``encoding='ISO-8859-1'`` tells pickle how to decode 8-bit strings that
    were pickled by Python 2; latin-1 maps every byte to a character, so
    decoding never fails.

    Warning: unpickling can execute arbitrary code. Only load files that come
    from a trusted source.
    """
    import pickle

    # `with` closes the file even if pickle.load raises.
    with open(file, 'rb') as fo:
        batch = pickle.load(fo, encoding='ISO-8859-1')
    return batch
# 显示其中部分数据 — Inspect part of the data.
# Peek at one training batch and at the test batch to see which keys they hold.
# Descriptive names are used instead of `dict`, which would shadow the builtin
# for the rest of the notebook.
train_batch = unpickle("data_batch_1")
print(train_batch.keys())
# Output: dict_keys(['batch_label', 'labels', 'data', 'filenames'])
test_batch = unpickle("test_batch")
print(test_batch.keys())
# Output: dict_keys(['batch_label', 'labels', 'data', 'filenames'])
# Load the training and test data: five training batches and one test batch.
train_batches = [unpickle(f'./data_batch_{i}') for i in range(1, 6)]
test_batch = unpickle('./test_batch')

# Join the per-batch pieces with np.vstack / np.hstack. X_train becomes one
# 2-D array with one row of pixel values per image; y_train becomes one flat
# label vector.
X_train = np.vstack([batch['data'] for batch in train_batches])
y_train = np.hstack([batch['labels'] for batch in train_batches])

# Build the test set as arrays too, so that y_test[k] is the label of test
# image k instead of y_test being a list that wraps a single inner label list.
X_test = np.asarray(test_batch['data'])
y_test = np.asarray(test_batch['labels'])
# Show the first training image as a sanity check. Each row is stored
# channel-first as (3, 32, 32). Moving the channel axis to the end gives the
# (32, 32, 3) layout that imshow expects.
plt.figure(figsize=(2.5, 3.5))
plt.imshow(np.moveaxis(X_train[0].reshape(3, 32, 32), 0, -1))
#导入算法包
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.decomposition import PCA
from sklearn.neighbors import KNeighborsClassifier
# 使用PCA进行降维 — Reduce dimensionality with PCA.
pca = PCA(n_components=100,whiten=True)
%time X_train_pca = pca.fit_transform(X_train)
# Project the test set with the PCA that was fitted on the training set.
# The test rows must keep the same pixel ordering as X_train, i.e. raw
# channel-first rows. Transposing them to channel-last first would feed PCA
# features in a different order than it was fitted on. The reshape only
# flattens a possible outer batch-list dimension; it does not reorder pixels.
X_test = np.asarray(X_test).reshape(-1, 3 * 32 * 32)
X_test_pca = pca.transform(X_test)
# Candidate classifiers to compare, keyed by a short display name.
# This is a plain dict of estimators, not an sklearn Pipeline. Each one is
# fitted on the same PCA features in the comparison loop that follows.
estimators = {}
estimators['logistic'] = LogisticRegression(max_iter=10000)
estimators['rfc'] = RandomForestClassifier(n_estimators=100)
estimators['svc'] = SVC()
estimators['knn'] = KNeighborsClassifier()
#使用降维后的数据进行训练,预测,并比较哪个算法得分高
preds = {}
for name,estimator in estimators.items():
%time estimator.fit(X_train_pca,y_train)
y_ = estimator.predict(X_test_pca)
preds[name] = y_
pca = PCA(n_components=100,whiten=True)
X_train_pca = pca.fit_transform(X_train)
print(name,estimator.score(X_train_pca,y_train))
Wall time: 3.92 s
logistic 0.4015
Wall time: 2min 45s
Compiler : 1.21 s
Parser : 2.08 s
rfc 0.99996
Wall time: 17min 3s
Compiler : 163 ms
Parser : 686 ms
svc 0.71868
Wall time: 120 ms
knn 0.45462
# Use SVC for the final predictions. Check the test accuracies printed by the
# comparison loop to confirm it is the best candidate.
# The comparison loop already fitted an identical SVC() on X_train_pca and
# stored its test-set predictions, so reuse both instead of retraining.
# SVC was by far the slowest model to train (about 17 minutes in the recorded run).
svc = estimators['svc']
y_ = preds['svc']
# Compare predictions with the ground truth on 100 randomly chosen test images.
# Misclassified images get a red "True/Predict" title; correctly classified
# ones are left untitled.
icon_name = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

# Re-read the raw test batch so that pixels are in their stored channel-first
# layout regardless of how X_test was reshaped for the model. Its rows are in
# the same order as the predictions in y_, which were made on this file's data.
test_batch = unpickle('./test_batch')
test_images = test_batch['data']
test_labels = test_batch['labels']

# Draw distinct indices from the whole test set, not just the first 1000 images.
n = np.random.choice(len(test_labels), size=100, replace=False)
plt.figure(figsize=(10 * 3, 10 * 3.5))
for i, index in enumerate(n):
    axes = plt.subplot(10, 10, i + 1)
    # Show the *test* image at `index`, the same image whose prediction is
    # checked below. The data is RGB, so no colormap applies.
    axes.imshow(test_images[index].reshape(3, 32, 32).transpose(1, 2, 0))
    axes.axis('off')
    if y_[index] != test_labels[index]:
        axes.set_title(f'True:{icon_name[test_labels[index]]}\nPredict:{icon_name[y_[index]]}',
                       fontdict={'fontsize': 15, 'color': 'r'})