该报错主要是由于代码
"plt.plot(data[i, 0], data[i, 1], marker='o', markersize=1, color=color_map[int(label[i])])"
中label[i]不是一个大小为1的数组。
要解决这个问题,应该确保label是一个一维数组,其中每个条目对应于一个数据点的标签。可以通过打印标签来检查标签的形状。在get_fer_data函数中加载后的形状。
将get_fer_data中代码改为:
def get_fer_data(data_path="./imgs_embed_npy.npy",
label_path="./label.npy"):
data = np.load(data_path)
label = np.load(label_path)
label = np.argmax(label, axis=1) # Convert one-hot encoded labels to scalar labels
n_samples, n_features = data.shape
return data, label, n_samples, n_features
即可