具体代码注释及算法原理可见博客:https://andyguo.blog.csdn.net/article/details/104336532
本文仅提供可运行代码及所需数据(在以上博客代码的基础上修正了个别错误)。
import numpy as np
import pandas as pd
from sklearn.impute import SimpleImputer
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.naive_bayes import GaussianNB
def load_dataset(feature_paths, label_paths):
    """Read feature and label files and stack them into NumPy arrays.

    Args:
        feature_paths: Iterable of comma-delimited, header-less feature
            files. A ``?`` entry is treated as a missing value and is
            replaced by the mean of its column within the same file.
            Every file must yield 41 columns, otherwise the final
            concatenation raises ``ValueError``.
        label_paths: Iterable of header-less label files, read with the
            default ``pandas.read_table`` settings.

    Returns:
        A ``(feature, label)`` tuple: ``feature`` is a 2-D float array with
        41 columns holding the rows of all feature files in order, and
        ``label`` is the 1-D flattened concatenation of all label files.
    """
    # The empty seed arrays keep the result shape and float dtype stable
    # even when no paths are supplied.
    feature_chunks = [np.empty((0, 41))]
    for path in feature_paths:
        frame = pd.read_table(path, delimiter=',', na_values='?', header=None)
        # read_table has already turned '?' into NaN, so NaN is the marker
        # to impute. The `verbose` argument is intentionally not passed: it
        # was removed from SimpleImputer and raises TypeError on current
        # scikit-learn releases.
        imputer = SimpleImputer(missing_values=np.nan, strategy='mean')
        feature_chunks.append(imputer.fit_transform(frame))
    # Concatenate once instead of re-copying the growing array per file.
    feature = np.concatenate(feature_chunks)

    label_chunks = [np.empty((0, 1))]
    for path in label_paths:
        label_chunks.append(pd.read_table(path, header=None).to_numpy())
    # Flatten the (n, 1) column into the 1-D vector the classifiers expect.
    label = np.ravel(np.concatenate(label_chunks))
    return feature, label
if __name__ == '__main__':
    # Datasets A-D form the training set; dataset E is held out for testing.
    featurePaths = ['A/A.feature','B/B.feature','C/C.feature','D/D.feature','E/E.feature']
    labelPaths = ['A/A.label','B/B.label','C/C.label','D/D.label','E/E.label']
    x_train, y_train = load_dataset(featurePaths[:4], labelPaths[:4])
    x_test, y_test = load_dataset(featurePaths[4:], labelPaths[4:])

    # One row per classifier: start banner, report banner, estimator class.
    stages = (
        ("Start training knn",
         "\n\nThe classification report for knn:",
         KNeighborsClassifier),
        ("Start training DT",
         "\n\nThe classification report for dt:",
         DecisionTreeClassifier),
        ("Start training Bayes",
         "\n\nThe classification report for gnb:",
         GaussianNB),
    )

    # Train and predict with every classifier first, keeping the predictions
    # so that all reports are printed together afterwards.
    predictions = []
    for start_banner, report_banner, estimator_cls in stages:
        print(start_banner)
        fitted = estimator_cls().fit(x_train, y_train)
        print("Training done")
        predicted = fitted.predict(x_test)
        print("Prediction done")
        predictions.append((report_banner, predicted))

    for report_banner, predicted in predictions:
        print(report_banner)
        print(classification_report(y_test, predicted))
数据链接:
链接:https://pan.baidu.com/s/1zFwrucWuJvk4X9hakQzufw?pwd=o0aj
提取码:o0aj
代码运行结果如下