距离新的一年还有14天,是时候要冲冲KPI了( _)
没错,这又是一篇改错的水贴,百无聊赖记录一下改错经历,虽然期末大作业还没写Uェ*U
一、问题描述
应用线性判别分析Linear Discriminant Analysis对经典手写数字数据集进行分类
首先导一下要用到的包和数据集,然后用LinearDiscriminantAnalysis()对数据集进行训练,然后就报错了ValueError: Found array with dim 3. LinearDiscriminantAnalysis expected <= 2.
import numpy as np
# 数字图像
x_train = np.load('mnist_x_train.npy') # 训练集数据
x_test = np.load('mnist_x_test.npy') # 测试集数据
# 对应的标签,即 0, 1, 2, ..., 9
y_train = np.load('mnist_y_train.npy')
y_test = np.load('mnist_y_test.npy')
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
# 应用Linear Discriminant Analysis来分类
LDA = LinearDiscriminantAnalysis()
LDA.fit(x_train, y_train)
</