# Step 3) Training
# -----------------------------------------------------
# Fit the estimator on the training data.
clf.fit(X_train, y_train)
# =======================================================
# Step 4) Evaluate the training result
# -----------------------------------------------------
# Only when scoring was requested: print clf.score() for the training data
# first, then for the test data.
if args.score:
    for score_caption, score_X, score_y in (
        ("Score in train set:", X_train, y_train),
        ("Score in test set:", X_test, y_test),
    ):
        print(score_caption, clf.score(score_X, score_y))
# =======================================================
# Step 5) Save the model
# -----------------------------------------------------
from joblib import dump, load

if args.output_model:
    # Build the file name once so dump() and load() are guaranteed to use
    # the same path.
    model_path = strModelName + "_" + strModelIdx + "_time" + ".joblib"
    dump(clf, model_path)
    # Read the file that was just written back into clf2.
    # NOTE(review): the reload is assumed to belong to this branch -- confirm.
    clf2 = load(model_path)
    # Disabled spot check of the reloaded model, kept for reference:
    #   y_test_pred = clf2.predict(X_test[0:10])
    #   print("Predict out:", y_test_pred)
    #   print("Label :", y_test[0:10])
# =======================================================
# Confusion matrices for the training set and the test set
from sklearn.metrics import confusion_matrix

if not args.confusion_matrix:
    # Nothing more to report: just announce that training is done.
    print("Training finish.")
else:
    # Predict on both sets before anything is printed.
    y_train_pred = clf.predict(X_train)
    y_test_pred = clf.predict(X_test)

    # Training set: the matrix is computed first, then header and matrix
    # are printed.
    resultOut = confusion_matrix(y_train.to_numpy(), y_train_pred)
    print("\n\nconfusion matrix in Train Set:")
    print(resultOut)

    # Separator and test-set header in one call; sep="\n" yields the same
    # output as two consecutive print() calls.
    print("=============", "\n\nconfusion matrix in Test Set:", sep="\n")

    # Test set: the header is already out, now compute and print the matrix.
    resultOut = confusion_matrix(y_test.to_numpy(), y_test_pred)
    print(resultOut)