sklearn介绍
- 基于Python
- 实现了几乎所有主流机器学习算法
- 简单易用的API定义
- 2010年末首次开源
- 高度活跃
- Python,Cython,Wrapper
- scikit-learn / sklearn
用sklearn数据集描述
In [1]:
# Import sklearn's bundled datasets module (provides the diabetes dataset used below).
from sklearn import datasets
In [2]:
# IPython introspection: the trailing `?` displays the help text for load_diabetes.
datasets.load_diabetes?
In [3]:
# Load the diabetes dataset; the cells below read its `.data` and `.target` attributes.
diabetes = datasets.load_diabetes()
In [4]:
# Feature matrix and regression target, bound in a single tuple assignment.
X, y = diabetes.data, diabetes.target
In [5]:
# Dimensions of the feature matrix: (n_samples, n_features).
X.shape
Out[5]:
(442, 10)
In [6]:
# Length of the target vector; it should equal the number of rows in X.
y.shape
Out[6]:
(442,)
In [7]:
# Peek at the first five rows of the feature matrix.
X[:5]
Out[7]:
array([[ 0.03807591, 0.05068012, 0.06169621, 0.02187235, -0.0442235 , -0.03482076, -0.04340085, -0.00259226, 0.01990842, -0.01764613], [-0.00188202, -0.04464164, -0.05147406, -0.02632783, -0.00844872, -0.01916334, 0.07441156, -0.03949338, -0.06832974, -0.09220405], [ 0.08529891, 0.05068012, 0.04445121, -0.00567061, -0.04559945, -0.03419447, -0.03235593, -0.00259226, 0.00286377, -0.02593034], [-0.08906294, -0.04464164, -0.01159501, -0.03665645, 0.01219057, 0.02499059, -0.03603757, 0.03430886, 0.02269202, -0.00936191], [ 0.00538306, -0.04464164, -0.03638469, 0.02187235, 0.00393485, 0.01559614, 0.00814208, -0.00259226, -0.03199144, -0.04664087]])
In [8]:
# Peek at the first ten target values.
y[:10]
Out[8]:
array([151., 75., 141., 206., 135., 97., 138., 63., 110., 310.])
In [9]:
# Labels for the ten feature columns of X (presumably in column order — confirm
# against the dataset's own description before relying on it).
feature_name = ['age', 'sex', 'bmi', 'bp', 's1', 's2', 's3', 's4', 's5', 's6']
In [10]:
# Third-party dependencies for the Lasso grid-search example below.
import numpy as np
from sklearn.linear_model import Lasso
from sklearn.model_selection import GridSearchCV
In [11]:
# Positional hold-out split: the first 310 samples train the model and the
# remaining samples are kept for testing.
X_train = diabetes.data[:310]
y_train = diabetes.target[:310]
X_test = diabetes.data[310:]
# Test labels must be sliced from `target`; slicing `data` here would yield a
# 2-D feature matrix instead of the 1-D label vector.
y_test = diabetes.target[310:]
In [12]:
# Linear model with L1 regularization: the penalty term shrinks some of the
# regression coefficients, which yields a more compact model.
lasso = Lasso(random_state=0)
# Candidate regularization strengths: 30 values log-spaced from 10**-4 to 10**-0.5.
alphas = np.logspace(-4, -0.5, 30)
In [13]:
# Estimator: a grid search over the candidate alpha values for the Lasso model.
estimator = GridSearchCV(estimator=lasso, param_grid={"alpha": alphas})
In [14]:
# Run the grid search on the training split, trying each candidate alpha.
estimator.fit(X_train, y_train)
Out[14]:
GridSearchCV(cv=None, error_score='raise', estimator=Lasso(alpha=1.0, copy_X=True, fit_intercept=True, max_iter=1000, normalize=False, positive=False, precompute=False, random_state=0, selection='cyclic', tol=0.0001, warm_start=False), fit_params=None, iid=True, n_jobs=1, param_grid={'alpha': array([1.00000e-04, 1.32035e-04, 1.74333e-04, 2.30181e-04, 3.03920e-04, 4.01281e-04, 5.29832e-04, 6.99564e-04, 9.23671e-04, 1.21957e-03, 1.61026e-03, 2.12611e-03, 2.80722e-03, 3.70651e-03, 4.89390e-03, 6.46167e-03, 8.53168e-03, 1.12648e-02, 1.48735e-02, 1.96383e-02, 2.59294e-02, 3.42360e-02, 4.52035e-02, 5.96846e-02, 7.88046e-02, 1.04050e-01, 1.37382e-01, 1.81393e-01, 2.39503e-01, 3.16228e-01])}, pre_dispatch='2*n_jobs', refit=True, return_train_score='warn', scoring=None, verbose=0)
In [15]:
# Score achieved by the best parameter setting found during the search.
estimator.best_score_
Out[15]:
0.4654063759023531
In [16]:
# The Lasso instance configured with the best alpha found by the search.
estimator.best_estimator_
Out[16]:
Lasso(alpha=0.02592943797404667, copy_X=True, fit_intercept=True, max_iter=1000, normalize=False, positive=False, precompute=False, random_state=0, selection='cyclic', tol=0.0001, warm_start=False)
In [17]:
# Predict target values for the held-out test samples.
estimator.predict(X_test)
Out[17]:
array([203.42104984, 177.6595529 , 122.62188598, 212.81136958, 173.61633075, 114.76145025, 202.36033584, 171.70767813, 164.28694562, 191.29091477, 191.41279009, 288.2772433 , 296.47009002, 234.53378413, 210.61427168, 228.62812055, 156.74489991, 225.08834492, 191.75874632, 102.81600989, 172.373221 , 111.20843429, 290.22242876, 178.64605207, 78.13722832, 86.35832297, 256.41378529, 165.99622543, 121.29260976, 153.48718848, 163.09835143, 180.0932902 , 161.4330553 , 155.80211635, 143.70181085, 126.13753819, 181.06471818, 105.03679977, 131.0479936 , 90.50606427, 252.66486639, 84.84786067, 59.41005358, 184.51368208, 201.46598714, 129.96333913, 90.65641478, 200.10932516, 55.2884802 , 171.60459062, 195.40750666, 122.14139787, 231.72783897, 159.49750022, 160.32104862, 165.53701866, 260.73217736, 259.77213787, 204.69526082, 185.66480969, 61.09821961, 209.9214333 , 108.50410841, 141.18424239, 126.10337002, 174.32819351, 214.4947322 , 162.1789921 , 160.57776438, 134.11449594, 171.63076427, 71.71500885, 263.46782314, 113.73653782, 112.76227977, 134.37721414, 110.67874472, 98.67153573, 157.2591359 , 78.32019218, 265.97090212, 57.85502185, 100.38532691, 101.91670102, 277.13032245, 168.6443445 , 64.75637937, 184.37359745, 174.74927914, 188.78215433, 181.56001383, 92.74463449, 145.41037529, 257.78620944, 196.57335354, 276.1920927 , 50.66776115, 179.12879963, 200.29366671, 167.29501922, 158.93206689, 156.08070427, 233.38241229, 125.30241353, 167.05404644, 171.66748431, 223.17843095, 156.7055944 , 103.29063169, 84.08205647, 139.87060658, 189.99648341, 200.20182211, 143.61906164, 170.00220231, 112.05886847, 160.76337573, 130.06232976, 261.83022688, 102.24589129, 115.12771477, 119.14505163, 225.96991263, 63.51874043, 134.88829709, 120.01764214, 55.32147904, 189.95346987, 105.8037979 , 120.46197038, 211.35568232, 56.78368048])