**不多说,直接上代码,为了防止直接抄而不是为了学习,我决定把相关的库就不放上去了,自己上网搜索库中包含的方法即可
def load_data():
data = pd.read_csv()#需要输入相关文件路径,如果是其他文件的需要查询pandas的read_其他格式
lable = pd.read_csv()
data_lable = data.merge(lable, how="left", on="USRID")
data_lable.drop(['USRID'], axis=1, inplace=True)
columns = data_lable.columns.tolist()
# print(columns)
feature_columns = [i for i in columns if i != "FLAG"]
# print(feature_columns)
data_array = data_lable[feature_columns].values#数据集
lable_array = data_lable['FLAG'].values#标签集
return train_test_split(data_array, lable_array, test_size = 0.25, random_state = 81,stratify = lable_array)#用于随机将样本集合划分为训练集 和测试集,并返回划分好的训练集和测试集数据。
def test_decision_tree(*data):
X_train,X_test,y_train,y_test=data
clf = DecisionTreeClassifier(criterion="entropy", max_d