网格搜索
class sklearn.model_selection.GridSearchCV(
estimator :考虑优化的估计器
param_grid : dict or list of dictionaries,希望进行搜索的参数阵
scoring = None : string/callable/list/tuple/dict/None,模型评分方法
fit_params = None, n_jobs = 1, cv = None
iid = True :数据是否在各fold间均匀分布,此时将直接最小化总样本的损失函数
refit = True :是否使用发现的最佳参数重新拟合估计器
verbose = 0, pre_dispatch = '2*n_jobs', error_score = 'raise'
return_train_score = True :是否返回训练集的评分
)
GridSearchCV类的属性:
cv_results_ :字典格式的参数列表,可被直接转换为pandas数据框
best_estimator_ :网格搜索得出的最佳模型
best_score_ :最佳模型的平均交叉验证得分
best_params_ : dict,最佳模型的参数设定
best_index_ : int,最佳模型对应的索引值
scorer_ : function or a dict,用于选择最佳模型的评分函数
n_splits_ : int,交叉验证的拆分数
GridSearchCV类的方法:
decision_function(*args, **kwargs) :调用筛选出的最佳模型并返回预测结果
其余标准API接口函数
decision_function返回数据点属于每个类别的判定系数,若为正数,则代表该点属于这一类,负数则表示该点不属于这一类。判定系数的绝对值越大,判断的可信度越高。
from sklearn import svm
from sklearn import datasets
from sklearn.model_selection import GridSearchCV
import pandas as pd
# Load the iris dataset bundled with scikit-learn.
iris = datasets.load_iris()

# Candidate values explored by the grid search.
# NOTE(review): [1.10] is a single candidate (C = 1.1). The scikit-learn
# documentation example uses [1, 10]; confirm the intent before changing it,
# because the recorded output in this file reflects C = 1.1.
parameters = {
    'kernel': ('linear', 'rbf'),
    'C': [1.10],
}

# Instantiate the base estimator and wrap it in the grid-search object.
svc = svm.SVC(probability=True)
clf = GridSearchCV(estimator=svc, param_grid=parameters)

# Run the search on the full dataset; as the last expression of the cell
# the fitted search object is also displayed.
clf.fit(iris.data, iris.target)
GridSearchCV(cv='warn', error_score='raise-deprecating',
estimator=SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
decision_function_shape='ovr', degree=3,
gamma='auto_deprecated', kernel='rbf', max_iter=-1,
probability=True, random_state=None, shrinking=True,
tol=0.001, verbose=False),
iid='warn', n_jobs=None,
param_grid={'C': [1.1], 'kernel': ('linear', 'rbf')},
pre_dispatch='2*n_jobs', refit=True, return_train_score=False,
scoring=None, verbose=0)
from sklearn.svm import SVC
# Show the cross-validation results of every candidate parameter setting
# as a pandas DataFrame (displayed as the value of the cell).
pd.DataFrame(clf.cv_results_)
mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_C | param_kernel | params | split0_test_score | split1_test_score | split2_test_score | mean_test_score | std_test_score | rank_test_score | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0.001341 | 0.000485 | 0.000000 | 0.00000 | 1.1 | linear | {'C': 1.1, 'kernel': 'linear'} | 1.000000 | 0.960784 | 1.000000 | 0.986667 | 0.018577 | 1 |
1 | 0.005318 | 0.005421 | 0.000333 | 0.00047 | 1.1 | rbf | {'C': 1.1, 'kernel': 'rbf'} | 0.980392 | 0.960784 | 0.979167 | 0.973333 | 0.009021 | 2 |
# Parameter setting of the best model found by the grid search.
clf.best_params_
{'C': 1.1, 'kernel': 'linear'}
# Grid search: print the decision values the selected best model assigns to
# every iris sample (one row per sample).
print(clf.decision_function(iris.data))
[[ 2.24627744 1.29829892 -0.30632837]
[ 2.23781119 1.29693571 -0.30471538]
[ 2.24548583 1.29718879 -0.30559412]
[ 2.23591041 1.29589056 -0.30389728]
[ 2.24795778 1.29822418 -0.3064484 ]
[ 2.23752685 1.29735806 -0.30495212]
[ 2.2434869 1.29624146 -0.30483534]
[ 2.24100113 1.29748846 -0.30534443]
[ 2.23661182 1.2953947 -0.30365738]
[ 2.2375786 1.2973143 -0.30492944]
[ 2.24558626 1.29895555 -0.30665596]
[ 2.23718064 1.29655761 -0.3044249 ]
[ 2.23997815 1.29729146 -0.30513039]
[ 2.2512266 1.29724082 -0.30622266]
[ 2.25842469 1.3013034 -0.30927775]
[ 2.25245732 1.29965175 -0.30772804]
[ 2.25194914 1.29891281 -0.3072508 ]
[ 2.24452273 1.29771579 -0.30581239]
[ 2.23723948 1.29851098 -0.30565066]
[ 2.24582594 1.29783236 -0.30600783]
[ 2.23172213 1.29768867 -0.30467686]
[ 2.24234659 1.29705619 -0.30520883]
[ 2.25994866 1.29879541 -0.30806837]
[ 2.22220568 1.29476653 -0.30201257]
[ 2.22238071 1.29523835 -0.30233811]
[ 2.22837826 1.29635472 -0.30354531]
[ 2.23246711 1.29580915 -0.30353915]
[ 2.24249366 1.29815377 -0.30588644]
[ 2.24452963 1.29837334 -0.30620727]
[ 2.23329426 1.2959152 -0.30367832]
[ 2.23097932 1.29600008 -0.30353221]
[ 2.23636238 1.29729451 -0.30481098]
[ 2.25330507 1.29966903 -0.30782355]
[ 2.2557716 1.30030318 -0.30843809]
[ 2.23545141 1.29669764 -0.30435806]
[ 2.2485707 1.29833029 -0.30657146]
[ 2.2491945 1.29959712 -0.30737741]
[ 2.24969301 1.29855531 -0.30681499]
[ 2.24251854 1.29603049 -0.30461341]
[ 2.24086415 1.29773868 -0.30548479]
[ 2.24814273 1.29786457 -0.3062567 ]
[ 2.22578517 1.29425389 -0.30197265]
[ 2.24585572 1.29640343 -0.30517211]
[ 2.22992334 1.29464894 -0.30258346]
[ 2.22639242 1.29551952 -0.3028421 ]
[ 2.23584715 1.29603506 -0.30398092]
[ 2.24398491 1.29802466 -0.30594642]
[ 2.24190045 1.29650998 -0.30483816]
[ 2.24570962 1.29872219 -0.30652718]
[ 2.24310378 1.29771639 -0.30567798]
[-0.25888157 2.27043006 0.84961887]
[-0.25420288 2.26613432 0.85736765]
[-0.26566202 2.26422889 1.03089525]
[-0.24725457 2.26389816 0.83953777]
[-0.26167944 2.26301954 0.97285293]
[-0.2558817 2.26139866 0.91727953]
[-0.25960377 2.26047526 0.98289875]
[-0.19807268 2.27437013 0.7462427 ]
[-0.25819096 2.26876747 0.85971716]
[-0.23803641 2.26291864 0.81574102]
[-0.22433167 2.27078106 0.76857795]
[-0.24705163 2.26470632 0.8338466 ]
[-0.24425962 2.27327002 0.78539706]
[-0.26147994 2.26048024 1.01994941]
[-0.21600922 2.27333011 0.75641414]
[-0.25140742 2.2712808 0.81140066]
[-0.2559883 2.25665492 0.98796188]
[-0.23905423 2.273164 0.77761803]
[-0.26533223 2.25720279 1.11452981]
[-0.23621673 2.27128811 0.78004053]
[-0.26495855 2.24680537 1.16854301]
[-0.23955415 2.2718155 0.78305262]
[-0.27060306 2.25321773 1.1785179 ]
[-0.26019837 2.26448894 0.92626119]
[-0.24921993 2.27097946 0.8065506 ]
[-0.25268911 2.26997762 0.82331433]
[-0.26531537 2.26507325 1.00569519]
[-0.27032769 2.25456052 1.17143878]
[-0.25759465 2.26029055 0.95342037]
[-0.20874231 2.27830251 0.74273904]
[-0.23361723 2.27122032 0.77690547]
[-0.22625653 2.27409475 0.76152715]
[-0.23467384 2.27203226 0.77563407]
[-0.27318245 1.24299275 2.21645134]
[-0.25579586 2.25427719 1.025428 ]
[-0.25255831 2.26139138 0.88790408]
[-0.2609764 2.26533568 0.92399971]
[-0.25977858 2.26537697 0.90938447]
[-0.23921988 2.26768065 0.79813843]
[-0.24402523 2.26517928 0.82109608]
[-0.2540305 2.26200534 0.89343164]
[-0.25783181 2.26278911 0.92107182]
[-0.24081212 2.27029624 0.79080786]
[-0.20208621 2.27458063 0.74730352]
[-0.2479946 2.26442341 0.83926823]
[-0.24139608 2.26915346 0.7964656 ]
[-0.24494709 2.26657115 0.81671407]
[-0.24899229 2.26941618 0.81397179]
[-0.16817711 2.27765932 0.73375002]
[-0.24293044 2.26735753 0.80760147]
[-0.28707384 1.14947354 2.28119096]
[-0.27544166 1.22551722 2.24241259]
[-0.28572654 1.21950229 2.26763227]
[-0.28075419 1.22722381 2.253935 ]
[-0.28496263 1.19919165 2.27167675]
[-0.29221928 1.20120343 2.28271138]
[-0.26350515 1.23390222 2.19566155]
[-0.28872304 1.22449799 2.27171205]
[-0.28536738 1.22270759 2.26580515]
[-0.28702396 1.1962758 2.27552215]
[-0.27295679 1.24123017 2.21846124]
[-0.27867286 1.23154175 2.24615515]
[-0.28099808 1.22753736 2.25431583]
[-0.27600114 1.21779044 2.24849707]
[-0.27838341 1.19196788 2.26253449]
[-0.27824419 1.21747936 2.25346649]
[-0.27892215 1.23592175 2.24334747]
[-0.29050259 1.20935777 2.27863355]
[-0.29632366 1.14630194 2.29280364]
[-0.27446017 1.24383688 2.21985224]
[-0.28353826 1.21366733 2.26542492]
[-0.27210994 1.22438478 2.23485951]
[-0.29332401 1.2025653 2.28418614]
[-0.2716128 1.24523345 2.20596604]
[-0.28175043 1.22264441 2.25843786]
[-0.28427267 1.23661991 2.25666445]
[-0.26887359 1.24715529 2.1886087 ]
[-0.2689827 1.24534071 2.19396581]
[-0.28314739 1.20961778 2.26596079]
[-0.28185127 1.24735998 2.24067838]
[-0.28776392 1.22705696 2.26896304]
[-0.2869051 1.23657142 2.26293051]
[-0.28371257 1.20202648 2.26897129]
[-0.27174181 1.25176897 2.18997946]
[-0.27988353 1.23805115 2.2441591 ]
[-0.2888502 1.21427543 2.27478929]
[-0.28160162 1.19629383 2.2668145 ]
[-0.27824449 1.23534175 2.24198858]
[-0.26693842 1.24624323 2.18133534]
[-0.27913621 1.23429921 2.24525582]
[-0.28339523 1.2030499 2.26818344]
[-0.27624857 1.23436839 2.23730158]
[-0.27544166 1.22551722 2.24241259]
[-0.28569412 1.20057856 2.27257418]
[-0.28405792 1.19376775 2.27133323]
[-0.27826174 1.22537803 2.24926502]
[-0.27556832 1.23623495 2.23346001]
[-0.27601105 1.23598869 2.23502853]
[-0.2782753 1.21132467 2.25621161]
[-0.27247876 1.23666418 2.22329881]]
随机搜索
在不明确可能的参数候选值时,可以在指定的参数值分布中进行取样,实现对参数的随机搜索。
class sklearn.model_selection.RandomizedSearchCV(
estimator :
param_distributions : dict,希望进行搜索的参数字典
n_iter = 10 : int,考虑抽取出的参数组合数,该参数用于控制计算总量
scoring = None, fit_params = None, n_jobs = 1, iid = True
refit = True :使用整个数据集重新fit搜索到的最佳参数组合模型
cv = None, verbose =0
pre_dispatch = '2*n_jobs', random_state = None, error_score = 'raise',
return_train_score = True
)
RandomizedSearchCV类的属性(和GridSearchCV类相同):
cv_results_ :字典格式的参数列表,可被直接转换为pandas数据框
best_estimator_ : 随机搜索得出的最佳模型
best_score_ :最佳模型的平均交叉验证得分
best_params_ : dict,最佳模型的参数设定
best_index_ : int,最佳模型对应的索引值
scorer_ : function or a dict,用于选择最佳模型的评分函数
n_splits_ : int,交叉验证的拆分数
RandomizedSearchCV类的方法(和GridSearchCV类相同):
decision_function(*args, **kwargs) :调用最佳模型,并返回预测结果
其余标准API接口函数
import scipy.stats as stats
from sklearn import datasets
from sklearn.model_selection import RandomizedSearchCV
import pandas as pd
from sklearn.svm import SVC
# Load the dataset and describe the search space.
iris = datasets.load_iris()

# Tuples are sampled as discrete choices; the scipy.stats objects are
# continuous distributions that the search draws values from.
parameters2 = {
    "kernel": ('linear', 'rbf'),
    'C': stats.expon(scale=100),
    'gamma': stats.expon(scale=1),
    'class_weight': ('balanced', None),
}

# Build the base estimator and the randomized search around it.
# No random_state is passed, so the sampled candidates differ between runs.
svc = SVC()
clf = RandomizedSearchCV(estimator=svc, param_distributions=parameters2)
clf.fit(iris.data, iris.target)

# Tabulate the cross-validation results of every sampled candidate
# (displayed as the value of the cell).
pd.DataFrame(clf.cv_results_)
mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_C | param_class_weight | param_gamma | param_kernel | params | split0_test_score | split1_test_score | split2_test_score | mean_test_score | std_test_score | rank_test_score | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0.000998 | 8.142963e-04 | 0.000344 | 0.000486 | 133.412 | balanced | 1.61514 | rbf | {'C': 133.41183918245417, 'class_weight': 'bal... | 1.000000 | 0.941176 | 0.958333 | 0.966667 | 0.024918 | 6 |
1 | 0.000000 | 0.000000e+00 | 0.000653 | 0.000462 | 15.8825 | None | 1.69397 | linear | {'C': 15.88249498314624, 'class_weight': None,... | 1.000000 | 0.921569 | 1.000000 | 0.973333 | 0.037154 | 1 |
2 | 0.000665 | 4.710309e-04 | 0.000333 | 0.000470 | 37.7037 | balanced | 0.079273 | linear | {'C': 37.703742462767295, 'class_weight': 'bal... | 1.000000 | 0.921569 | 1.000000 | 0.973333 | 0.037154 | 1 |
3 | 0.000997 | 4.052337e-07 | 0.000333 | 0.000471 | 25.3494 | balanced | 0.951097 | rbf | {'C': 25.34942982674127, 'class_weight': 'bala... | 0.980392 | 0.941176 | 1.000000 | 0.973333 | 0.024415 | 1 |
4 | 0.000676 | 4.779710e-04 | 0.000343 | 0.000485 | 245.106 | balanced | 0.663588 | linear | {'C': 245.10614694204088, 'class_weight': 'bal... | 1.000000 | 0.921569 | 1.000000 | 0.973333 | 0.037154 | 1 |
5 | 0.000661 | 4.685966e-04 | 0.000337 | 0.000476 | 22.3682 | None | 0.0174971 | linear | {'C': 22.368243827973412, 'class_weight': None... | 1.000000 | 0.921569 | 1.000000 | 0.973333 | 0.037154 | 1 |
6 | 0.000333 | 4.703588e-04 | 0.000676 | 0.000478 | 227.165 | None | 0.318804 | rbf | {'C': 227.16515187404892, 'class_weight': None... | 0.980392 | 0.901961 | 1.000000 | 0.960000 | 0.042411 | 10 |
7 | 0.000332 | 4.691225e-04 | 0.000000 | 0.000000 | 125.798 | None | 0.11144 | rbf | {'C': 125.79773872651727, 'class_weight': None... | 1.000000 | 0.901961 | 1.000000 | 0.966667 | 0.046442 | 6 |
8 | 0.000000 | 0.000000e+00 | 0.000997 | 0.000027 | 3.60636 | None | 2.07116 | linear | {'C': 3.6063649298660225, 'class_weight': None... | 0.980392 | 0.921569 | 1.000000 | 0.966667 | 0.033333 | 6 |
9 | 0.000665 | 4.705275e-04 | 0.000322 | 0.000455 | 61.6049 | None | 0.429071 | rbf | {'C': 61.60492358027872, 'class_weight': None,... | 1.000000 | 0.901961 | 1.000000 | 0.966667 | 0.046442 | 6 |
# Parameter setting of the best model found by the randomized search.
clf.best_params_
{'C': 15.88249498314624,
'class_weight': None,
'gamma': 1.693966803275749,
'kernel': 'linear'}