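Purpose: take the class labels from the folder names, encode them with scikit-learn's LabelEncoder (which assigns each class an integer, not a one-hot vector), split the dataset, and save the splits as binary .npy files.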
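Process: after the images in each folder have been loaded and the folder names collected, each folder name serves as the class label. The labels are encoded as integers, the data is split into training and test sets, and every resulting array is saved in the binary .npy format.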
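The encoding step can be illustrated without any image data. The sketch below uses hypothetical folder names and plain NumPy: `np.unique(..., return_inverse=True)` yields the same sorted class list and integer labels that `LabelEncoder` produces, and indexing an identity matrix turns those integers into one-hot rows if a model needs them. `encode_labels` and `to_one_hot` are illustrative helper names, not part of any library.

```python
import numpy as np


def encode_labels(folder_names):
    """Hypothetical helper: map each folder-name label to an integer index.
    Classes are sorted, which matches the order LabelEncoder uses."""
    classes, int_labels = np.unique(folder_names, return_inverse=True)
    return classes, int_labels


def to_one_hot(int_labels, n_classes):
    """Hypothetical helper: row i has a single 1 at column int_labels[i]."""
    return np.eye(n_classes, dtype=np.float32)[int_labels]


# Hypothetical labels: one folder name per loaded image
y = ["dog", "cat", "bird", "cat", "dog"]
classes, y_int = encode_labels(y)
y_onehot = to_one_hot(y_int, len(classes))

print(classes)         # ['bird' 'cat' 'dog']
print(y_int)           # [2 1 0 1 2]
print(y_onehot.shape)  # (5, 3)
```

Storing only the integer labels keeps the saved files small; the one-hot matrix can be rebuilt from them at training time with the same indexing trick.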
Python code:
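import os
import numpy as np
from sklearn import preprocessing
from sklearn.model_selection import train_test_split

# x (images), y (folder-name label per image), y_label and DATA_DIR come from the earlier loading step
# Label encoding: map each class (folder name) to an integer 0..n_classes-1
le = preprocessing.LabelEncoder()
le.fit([d for d in os.listdir(DATA_DIR) if os.path.isdir(os.path.join(DATA_DIR, d))])
print(list(le.classes_))  # classes in sorted order; the list index is the integer label
y = le.transform(y)
print(x.shape, y.shape)
# Split the dataset (70% train / 30% test) and save each part as a .npy file
x_train, x_test, y_train, y_test, y_label_train, y_label_test = train_test_split(
    x, y, y_label, random_state=1, test_size=0.3)
np.save("x_train.npy", x_train)
np.save("y_train.npy", y_train)
np.save("y_label_train.npy", y_label_train)
np.save("x_test.npy", x_test)
np.save("y_test.npy", y_test)
np.save("y_label_test.npy", y_label_test)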
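The split-and-save step can be checked end to end on synthetic data. This is a NumPy-only stand-in for `train_test_split`, not scikit-learn's implementation: it shuffles every array with one shared permutation so images and labels stay aligned, holds out ceil(0.3 · n) rows as the test set, writes the same file names used above to a temporary directory, and reads them back with `np.load`. The data, the helper name `split_and_save`, and the reading of `y_label` as a class-name array are assumptions; the shuffle also differs from `random_state=1`, so the exact rows chosen will not match.

```python
import os
import tempfile

import numpy as np


def split_and_save(arrays, names, out_dir, test_size=0.3, seed=1):
    """Hypothetical helper: shuffle all arrays with ONE shared permutation so rows
    stay aligned, hold out ceil(test_size * n) rows as the test set, and save
    each part as <name>_train.npy / <name>_test.npy in out_dir."""
    n = len(arrays[0])
    n_test = int(np.ceil(test_size * n))
    perm = np.random.default_rng(seed).permutation(n)
    train_idx, test_idx = perm[: n - n_test], perm[n - n_test:]
    for arr, name in zip(arrays, names):
        np.save(os.path.join(out_dir, f"{name}_train.npy"), arr[train_idx])
        np.save(os.path.join(out_dir, f"{name}_test.npy"), arr[test_idx])
    return train_idx, test_idx


# Synthetic stand-ins: 10 tiny 4x4 "images", integer labels, matching class names
x = np.arange(10 * 4 * 4).reshape(10, 4, 4)
y = np.arange(10) % 3
y_label = np.array(["cat", "dog", "fox"])[y]

with tempfile.TemporaryDirectory() as out_dir:
    train_idx, test_idx = split_and_save([x, y, y_label], ["x", "y", "y_label"], out_dir)
    saved = sorted(os.listdir(out_dir))
    # Read the files back: each .npy file restores the array's shape and dtype
    x_train = np.load(os.path.join(out_dir, "x_train.npy"))
    y_train = np.load(os.path.join(out_dir, "y_train.npy"))
    y_label_test = np.load(os.path.join(out_dir, "y_label_test.npy"))

print(saved)
print(x_train.shape, y_label_test.shape)  # (7, 4, 4) (3,)
```

Using one permutation for all arrays is the design point that matters: shuffling x and y separately would silently pair images with the wrong labels.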
Output: