Training:
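import numpy as np
import cv2

def main():
    (train, train_labels), (test, test_labels) = XX_load_data()
    # Scale pixels to [0, 1] and flatten each sample into one row (ROW_SAMPLE expects a 2-D float32 array)
    train = np.array(train / 255., dtype=np.float32).reshape(len(train), -1)
    test = np.array(test / 255., dtype=np.float32).reshape(len(test), -1)
    # KNearest only accepts float32 or int32 responses
    train_labels = np.array(train_labels, dtype=np.float32)
    knn = cv2.ml.KNearest_create()
    print(train.shape, train_labels.shape)
    print(test.shape, test_labels.shape)
    knn.train(train, cv2.ml.ROW_SAMPLE, train_labels)
    # knn.save("XXX")  # uncomment to write the model file that the prediction step loads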
Here, XX_load_data is a user-defined function that loads the data.
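The loader itself is not shown. The sketch below is a hypothetical stand-in for XX_load_data that returns random uint8 images and int32 labels. It only illustrates the `(train, train_labels), (test, test_labels)` layout that the training code unpacks. The image size, class count and random data are placeholders, not the author's dataset.

```python
import numpy as np

def XX_load_data(n_train=100, n_test=20, height=28, width=28, n_classes=10, seed=0):
    """Hypothetical stand-in for the user-defined loader (random placeholder data).

    Returns (train, train_labels), (test, test_labels) with images as uint8
    arrays of shape (n, height, width) and labels as int32 arrays of shape (n,).
    """
    rng = np.random.default_rng(seed)
    train = rng.integers(0, 256, size=(n_train, height, width), dtype=np.uint8)
    train_labels = rng.integers(0, n_classes, size=n_train).astype(np.int32)
    test = rng.integers(0, 256, size=(n_test, height, width), dtype=np.uint8)
    test_labels = rng.integers(0, n_classes, size=n_test).astype(np.int32)
    return (train, train_labels), (test, test_labels)
```

A real loader only needs to keep the same return layout. Whether the images come back as 2-D images or already-flattened rows does not matter, because the training code reshapes them to `(n, features)`.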
Prediction:
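import numpy as np
import cv2

knn = cv2.ml.KNearest_load(r"XXX")
(train, train_labels), (test, test_labels) = Cap_load_data()
print(test.shape, test_labels.shape)
# Apply the same preprocessing as in training: scale to [0, 1], one row per sample
test = np.array(test / 255., dtype=np.float32).reshape(len(test), -1)
train = np.array(train / 255., dtype=np.float32).reshape(len(train), -1)
# To classify a single sample, reshape it into a 1 x n_features row:
# src = test[0].reshape(1, -1)
# print(test[0])
# result holds the predicted label of each test row (majority vote among the k=3 nearest neighbours)
ret, result, neighbours, dist = knn.findNearest(test, k=3)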
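To show what the four return values of `findNearest` mean, here is a brute-force k-NN classifier written in plain NumPy. It is a sketch of the technique, not OpenCV's implementation. It assumes non-negative integer labels (needed by `np.bincount`) and uses squared Euclidean distance. Tie-breaking and the exact distance values may differ from OpenCV's. `knn_predict` and `accuracy` are hypothetical helper names.

```python
import numpy as np

def knn_predict(train, train_labels, test, k=3):
    """Brute-force k-NN classification (sketch of the idea behind findNearest).

    Returns (results, neighbour_labels, dists), one row per test sample.
    """
    # Squared Euclidean distance from every test row to every train row: (n_test, n_train)
    d = ((test[:, None, :] - train[None, :, :]) ** 2).sum(axis=2)
    # Indices of the k closest training rows for each test row, nearest first
    idx = np.argsort(d, axis=1, kind="stable")[:, :k]
    neighbour_labels = train_labels[idx]            # like `neighbours`
    dists = np.take_along_axis(d, idx, axis=1)      # like `dist`
    # Majority vote among the k neighbours; argmax picks the smallest label on ties
    results = np.array([np.bincount(row).argmax() for row in neighbour_labels])
    return results, neighbour_labels, dists

def accuracy(results, labels):
    # findNearest returns results as an (n, 1) float32 array, so flatten before comparing
    return float(np.mean(np.ravel(results) == np.ravel(labels)))
```

With the OpenCV output above, `accuracy(result, test_labels)` gives the fraction of test samples whose predicted label matches the ground truth. The NumPy version builds an `(n_test, n_train, features)` intermediate, so it is only practical for small datasets.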