看 Machine Learning for OpenCV 4 second Edition, Aditya Sharma et al(中文名: 机器学习-使用OpenCV、Python和scikit-learn进行智能图像处理(原书第二版), 刘冰 译, 机械工业出版社)
5.2.1 P96
该章节是根据患者年龄、性别、血压、胆固醇浓度等因素(data)判断应当开什么药(target)的建立决策树的机器学习过程,代码在github是开源的,因此在这不加引用的给出:
data = [
{'age': 33, 'sex': 'F', 'BP': 'high', 'cholesterol': 'high', 'Na': 0.66, 'K': 0.06, 'drug': 'A'},
{'age': 77, 'sex': 'F', 'BP': 'high', 'cholesterol': 'normal', 'Na': 0.19, 'K': 0.03, 'drug': 'D'},
{'age': 88, 'sex': 'M', 'BP': 'normal', 'cholesterol': 'normal', 'Na': 0.80, 'K': 0.05, 'drug': 'B'},
{'age': 39, 'sex': 'F', 'BP': 'low', 'cholesterol': 'normal', 'Na': 0.19, 'K': 0.02, 'drug': 'C'},
{'age': 43, 'sex': 'M', 'BP': 'normal', 'cholesterol': 'high', 'Na': 0.36, 'K': 0.03, 'drug': 'D'},
{'age': 82, 'sex': 'F', 'BP': 'normal', 'cholesterol': 'normal', 'Na': 0.09, 'K': 0.09, 'drug': 'C'},
{'age': 40, 'sex': 'M', 'BP': 'high', 'cholesterol': 'normal', 'Na': 0.89, 'K': 0.02, 'drug': 'A'},
{'age': 88, 'sex': 'M', 'BP': 'normal', 'cholesterol': 'normal', 'Na': 0.80, 'K': 0.05, 'drug': 'B'},
{'age': 29, 'sex': 'F', 'BP': 'high', 'cholesterol': 'normal', 'Na': 0.35, 'K': 0.04, 'drug': 'D'},
{'age': 53, 'sex': 'F', 'BP': 'normal', 'cholesterol': 'normal', 'Na': 0.54, 'K': 0.06, 'drug': 'C'},
{'age': 36, 'sex': 'F', 'BP': 'high', 'cholesterol': 'high', 'Na': 0.53, 'K': 0.05, 'drug': 'A'},
{'age': 63, 'sex': 'M', 'BP': 'low', 'cholesterol': 'high', 'Na': 0.86, 'K': 0.09, 'drug': 'B'},
{'age': 60, 'sex': 'M', 'BP': 'low', 'cholesterol': 'normal', 'Na': 0.66, 'K': 0.04, 'drug': 'C'},
{'age': 55, 'sex': 'M', 'BP': 'high', 'cholesterol': 'high', 'Na': 0.82, 'K': 0.04, 'drug': 'B'},
{'age': 35, 'sex': 'F', 'BP': 'normal', 'cholesterol': 'high', 'Na': 0.27, 'K': 0.03, 'drug': 'D'},
{'age': 23, 'sex': 'F', 'BP': 'high', 'cholesterol': 'high', 'Na': 0.55, 'K': 0.08, 'drug': 'A'},
{'age': 49, 'sex': 'F', 'BP': 'low', 'cholesterol': 'normal', 'Na': 0.27, 'K': 0.05, 'drug': 'C'},
{'age': 27, 'sex': 'M', 'BP': 'normal', 'cholesterol': 'normal', 'Na': 0.77, 'K': 0.02, 'drug': 'B'},
{'age': 51, 'sex': 'F', 'BP': 'low', 'cholesterol': 'high', 'Na': 0.20, 'K': 0.02, 'drug': 'D'},
{'age': 38, 'sex': 'M', 'BP': 'high', 'cholesterol': 'normal', 'Na': 0.78, 'K': 0.05, 'drug': 'A'}
]
#生成新数据
import random
def generateBasorexiaData(num_entries):
# we will save our new entries in this list
list_entries = []
for entry_count in range(num_entries):
new_entry = {}
new_entry['age'] = random.randint(20,100)
new_entry['sex'] = random.choice('M','F')
new_entry['BP'] = random.choice('low','high','normal')
new_entry['cholesterol'] = random.choice('low','high','normal')
new_entry['Na'] = random.random()
new_entry['K'] = random.random()
new_entry['drug'] = random.choice('A','B','C','D')
list_entries.append(new_entry)
return list_entries
# 去除drug项,因为这实质上是target
target = [d['drug'] for d in data]
[d.pop('drug') for d in data]
import matplotlib.pyplot as plt
plt.style.use("ggplot")
age = [d['age'] for d in data]
sodium = [d['Na'] for d in data]
potassium = [d['K'] for d in data]
target = [ord(t) - 65 for t in target]
plt.subplot(221)
plt.scatter(sodium, potassium, c=target,s=40)
plt.xlabel('sodium (Na)')
plt.ylabel('potassium (K)')
plt.subplot(222)
plt.scatter(age, potassium, c=target, s=40)
plt.xlabel('age')
plt.ylabel('potassium (K)')
plt.subplot(223)
plt.scatter(age, sodium, c=target, s=40)
plt.xlabel('age')
plt.ylabel('sodium (Na)')
# 数据处理 DictVectorizer
from sklearn.feature_extraction import DictVectorizer
vec = DictVectorizer(sparse=False)
data_pre = vec.fit_transform(data) #独热编码
# 兼容数据
import numpy as np
data_pre = np.array(data_pre, dtype=np.float32)
target = np.array(target, dtype=np.float32)
# 拆分训练集
import sklearn.model_selection as ms
X_train, X_test, y_train, y_test = ms.train_test_split(data_pre, target, test_size=5, random_state=42)
# 构建决策树
import cv2
dtree = cv2.ml.DTrees_create()
dtree.train(X_train, cv2.ml.ROW_SAMPLE, y_train)
y_pred = dtree.predict(X_test)
# 这个经典的决策树存在明显过拟合的问题,即它记住了所有的正确答案。
主要是应用了cv2.ml.DTrees_create类
但是在运行的时候会出现一个错误:
我实在是百思不得其解。
data是字典列表,列表的长度只有20,经过train_test_split之后,只剩下15个,这个长度不断特别长。数据是经过处理与opencv兼容的。更离谱的是,将train_test_split参数test_size改为19(即只训练一个数据点),还是vector too long...
离谱...真的太离谱了...
我看这个问题之前没有别人遇到过...就在这里记录一下,我比较菜,没能解决它...