# -*- coding: utf-8 -*-
"""Train a decision tree on the Wine data set and plot predictions.

The script loads features and integer labels from two CSV files, fits a
``DecisionTreeClassifier`` on the first 80% of the rows, predicts the
remaining 20%, and shows the predicted and true labels as two stacked
scatter plots.
"""
import numpy as np
from sklearn import tree
import matplotlib.pyplot as plt

# Load the data.
filename = './data/Wine.csv'
labelname = './data/label_wine.csv'
# Passing the path lets numpy open and close the file itself, so no file
# handle is left dangling.
data = np.loadtxt(filename, delimiter=",")
label = np.loadtxt(labelname, dtype=int, delimiter=",")

# Features and labels must line up row by row; fail early with a clear
# message instead of a shape error deep inside fit() or scatter().
if data.shape[0] != label.shape[0]:
    raise ValueError(
        "row count mismatch: %d samples in %s but %d labels in %s"
        % (data.shape[0], filename, label.shape[0], labelname)
    )

# Use the first 80% of the rows for training and the last 20% for testing.
split = int(data.shape[0] * 0.8)
data_train, data_test = data[:split], data[split:]
label_train, label_test = label[:split], label[split:]

# Train the model and predict the held-out rows.
clf = tree.DecisionTreeClassifier()
clf.fit(data_train, label_train)
label_predict = clf.predict(data_test)

# Compare the predictions with the true labels.
x = range(data_test.shape[0])
fig = plt.figure()
ax1 = fig.add_subplot(211)
ax2 = fig.add_subplot(212)
ax1.set_title('Predict cluster')
ax2.set_title('True cluster')
# pyplot labels apply to the current axes, i.e. the most recently created
# subplot (ax2).
plt.xlabel('samples')
plt.ylabel('label')
ax1.scatter(x, label_predict, c=label_predict, marker='o')
ax2.scatter(x, label_test, c=label_test, marker='s')
plt.show()