python 画出决策边界_画出决策边界线--plot_2d_separator.py源代码【来自python机器学习基础教程】...

import numpy as np
import matplotlib.pyplot as plt

# Colormaps and the scatter helper come from the sibling plot_helpers module.
from .plot_helpers import cm2, cm3, discrete_scatter

def _call_classifier_chunked(classifier_pred_or_decide, X, chunk_size=10000):
    """Apply a prediction callable to ``X`` in row chunks and join the results.

    Parameters
    ----------
    classifier_pred_or_decide : callable
        Called once per chunk with a slice of rows of ``X``; must return an
        array so the per-chunk results can be concatenated.
    X : ndarray
        Input rows; it is split along axis 0.
    chunk_size : int, default 10000
        Maximum number of rows passed to the callable at once.

    Returns
    -------
    ndarray
        ``np.concatenate`` of the per-chunk results, in row order.

    Raises
    ------
    ValueError
        If ``chunk_size`` is smaller than 1.
    """
    # The chunk_size is used to chunk the large arrays to work with x86
    # memory models that are restricted to < 2 GB in memory allocation. The
    # default chunk_size is based on a measurement with the MLPClassifier
    # using the following parameters:
    #   MLPClassifier(solver='lbfgs', random_state=0,
    #                 hidden_layer_sizes=[1000, 1000, 1000])
    # By reducing the value it is possible to trade in time for memory.
    # It is possible to chunk the array as the calculations are independent
    # of each other.
    # Note: an intermittent version made a distinction between 32- and 64 bit
    # architectures avoiding the chunking. Testing revealed that even on
    # 64 bit architectures the chunking increases the performance by a factor
    # of 3-5, largely due to the avoidance of memory swapping.
    if chunk_size < 1:
        raise ValueError("chunk_size must be a positive integer")

    # Row indices at which a new chunk starts; np.intp is the native index
    # type, so the indices cannot overflow for very tall inputs.
    split_points = np.arange(chunk_size, X.shape[0], chunk_size,
                             dtype=np.intp)

    # Call the classifier in chunks and collect all result chunks.
    Y_result_chunks = [
        classifier_pred_or_decide(x_chunk)
        for x_chunk in np.array_split(X, split_points, axis=0)
    ]

    return np.concatenate(Y_result_chunks)

def plot_2d_classification(classifier, X, fill=False, ax=None, eps=None,
                           alpha=1, cm=cm3):
    """Draw the predicted class of ``classifier`` over a 2-D grid as an image.

    Parameters
    ----------
    classifier : object
        Must provide ``predict``; it is called once on the whole grid.
    X : ndarray
        Data whose first two columns define the plotted region.
    fill : bool, default False
        Accepted for signature compatibility; not used by this function.
    ax : axes object, optional
        Target axes; defaults to ``plt.gca()``.
    eps : float, optional
        Margin added around the data range; defaults to ``X.std() / 2``.
    alpha : float, default 1
        Opacity passed to ``imshow``.
    cm : colormap, default cm3
        Colormap passed to ``imshow``.
    """
    # multiclass
    if eps is None:
        eps = X.std() / 2.

    if ax is None:
        ax = plt.gca()

    x_min, x_max = X[:, 0].min() - eps, X[:, 0].max() + eps
    y_min, y_max = X[:, 1].min() - eps, X[:, 1].max() + eps
    xx = np.linspace(x_min, x_max, 1000)
    yy = np.linspace(y_min, y_max, 1000)

    # Every grid node becomes one two-feature sample for the classifier.
    X1, X2 = np.meshgrid(xx, yy)
    X_grid = np.c_[X1.ravel(), X2.ravel()]
    decision_values = classifier.predict(X_grid)
    ax.imshow(decision_values.reshape(X1.shape),
              extent=(x_min, x_max, y_min, y_max),
              aspect='auto', origin='lower', alpha=alpha, cmap=cm)
    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)
    ax.set_xticks(())
    ax.set_yticks(())

def plot_2d_scores(classifier, X, ax=None, eps=None, alpha=1, cm="viridis",
                   function=None):
    """Draw a classifier's continuous scores over a 2-D grid as an image.

    Parameters
    ----------
    classifier : object
        Must provide the scoring method selected by ``function``.
    X : ndarray
        Data whose first two columns define the plotted region.
    ax : axes object, optional
        Target axes; defaults to ``plt.gca()``.
    eps : float, optional
        Margin added around the data range; defaults to ``X.std() / 2``.
    alpha : float, default 1
        Opacity passed to ``imshow``.
    cm : colormap or str, default "viridis"
        Colormap passed to ``imshow``.
    function : str, optional
        Name of the classifier method to call on the grid. When omitted,
        ``decision_function`` is used if present, otherwise
        ``predict_proba``.

    Returns
    -------
    The object returned by ``ax.imshow``.

    Raises
    ------
    AttributeError
        If the requested (or fallback) method does not exist.
    """
    # binary with fill
    if eps is None:
        eps = X.std() / 2.

    if ax is None:
        ax = plt.gca()

    x_min, x_max = X[:, 0].min() - eps, X[:, 0].max() + eps
    y_min, y_max = X[:, 1].min() - eps, X[:, 1].max() + eps
    xx = np.linspace(x_min, x_max, 100)
    yy = np.linspace(y_min, y_max, 100)

    X1, X2 = np.meshgrid(xx, yy)
    X_grid = np.c_[X1.ravel(), X2.ravel()]
    if function is None:
        # Only look up predict_proba when decision_function is missing.
        # A nested getattr default would be evaluated eagerly and raise for
        # classifiers that offer decision_function but no predict_proba.
        if hasattr(classifier, "decision_function"):
            function = classifier.decision_function
        else:
            function = classifier.predict_proba
    else:
        function = getattr(classifier, function)
    decision_values = function(X_grid)
    if decision_values.ndim > 1 and decision_values.shape[1] > 1:
        # predict_proba: keep the second column only
        decision_values = decision_values[:, 1]
    grr = ax.imshow(decision_values.reshape(X1.shape),
                    extent=(x_min, x_max, y_min, y_max), aspect='auto',
                    origin='lower', alpha=alpha, cmap=cm)

    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)
    ax.set_xticks(())
    ax.set_yticks(())
    return grr

def plot_2d_separator(classifier, X, fill=False, ax=None, eps=None, alpha=1,
                      cm=cm2, linewidth=None, threshold=None,
                      linestyle="solid"):
    """Draw the decision boundary of a classifier over a 2-D grid.

    Parameters
    ----------
    classifier : object
        Must provide ``decision_function`` or, failing that,
        ``predict_proba``.
    X : ndarray
        Data whose first two columns define the plotted region.
    fill : bool, default False
        If true, draw filled regions with ``contourf``; otherwise draw only
        the boundary line with ``contour``.
    ax : axes object, optional
        Target axes; defaults to ``plt.gca()``.
    eps : float, optional
        Margin added around the data range; defaults to ``X.std() / 2``.
    alpha : float, default 1
        Opacity passed to the contour call.
    cm : colormap, default cm2
        Colormap used for the filled variant.
    linewidth : optional
        Passed as ``linewidths`` to ``contour``.
    threshold : float, optional
        Contour level; defaults to 0 for ``decision_function`` and 0.5 for
        ``predict_proba``.
    linestyle : str, default "solid"
        Passed as ``linestyles`` to ``contour``.
    """
    # binary?
    if eps is None:
        eps = X.std() / 2.

    if ax is None:
        ax = plt.gca()

    x_min, x_max = X[:, 0].min() - eps, X[:, 0].max() + eps
    y_min, y_max = X[:, 1].min() - eps, X[:, 1].max() + eps
    xx = np.linspace(x_min, x_max, 1000)
    yy = np.linspace(y_min, y_max, 1000)

    X1, X2 = np.meshgrid(xx, yy)
    X_grid = np.c_[X1.ravel(), X2.ravel()]
    if hasattr(classifier, "decision_function"):
        decision_values = _call_classifier_chunked(
            classifier.decision_function, X_grid)
        levels = [0] if threshold is None else [threshold]
        # Filled regions span from the lowest to the highest observed score.
        fill_levels = ([decision_values.min()] + levels
                       + [decision_values.max()])
    else:
        # no decision_function: use the second predict_proba column
        decision_values = _call_classifier_chunked(
            classifier.predict_proba, X_grid)[:, 1]
        levels = [.5] if threshold is None else [threshold]
        fill_levels = [0] + levels + [1]
    if fill:
        ax.contourf(X1, X2, decision_values.reshape(X1.shape),
                    levels=fill_levels, alpha=alpha, cmap=cm)
    else:
        ax.contour(X1, X2, decision_values.reshape(X1.shape), levels=levels,
                   colors="black", alpha=alpha, linewidths=linewidth,
                   linestyles=linestyle, zorder=5)

    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)
    ax.set_xticks(())
    ax.set_yticks(())

if __name__ == '__main__':
    # Demo: fit a logistic regression on two blobs, then draw the filled
    # decision regions with the training points on top.
    from sklearn.datasets import make_blobs
    from sklearn.linear_model import LogisticRegression
    X, y = make_blobs(centers=2, random_state=42)
    clf = LogisticRegression(solver='lbfgs').fit(X, y)
    plot_2d_separator(clf, X, fill=True)
    discrete_scatter(X[:, 0], X[:, 1], y)
    plt.show()

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值