Linear Regression
[Keywords] least squares, linearity
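1. Linear equations
If the label y is linearly related to a single feature x, the relationship between x and y can be written as:
$$y = wx + b$$
where w and b are constants; w is called the slope and b the intercept.
If the label y is linearly related to two features x1 and x2, the relationship is:
$$y = w_1x_1 + w_2x_2 + b$$
Let $W = (w_1, w_2)$ and $X = (x_1, x_2)$. W and X are called the coefficient vector and the feature vector, and the relationship between y and X can be written as
$$y = WX^T + b$$
In this equation W is the coefficient vector (the weights), X is the feature vector and b is the intercept.
If the label y is linearly related to n features $x_1, x_2, \dots, x_n$, the relationship is:
$$y = w_1x_1 + w_2x_2 + \dots + w_nx_n + b$$
Let $W = (w_1, w_2, \dots, w_n)$ and $X = (x_1, x_2, \dots, x_n)$.
The relationship between the label y and the feature vector X can then be written as
$$y = WX^T + b$$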
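In NumPy this formula can be evaluated for several samples at once by stacking the feature vectors as the rows of a matrix. The sketch below is only an illustration: the coefficients, intercept and samples are made-up numbers.

```python
import numpy as np

# Hypothetical coefficient vector W = (w1, w2, w3) and intercept b
W = np.array([2.0, -1.0, 0.5])
b = 4.0

# Three hypothetical samples; each row is one feature vector X = (x1, x2, x3)
X = np.array([[1.0, 2.0, 3.0],
              [0.0, 0.0, 0.0],
              [2.0, 1.0, -2.0]])

# y = W X^T + b, computed for every row in one step
y = W @ X.T + b
print(y)  # -> 5.5, 4.0, 6.0
```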
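2. Substituting the samples
Suppose we have one sample with feature vector $X_1 = (x_{1,1}, x_{1,2}, \dots, x_{1,n})$ and label $y_1$. If $y_1$ is linearly related to $X_1$, then $y_1 = WX_1^T + b$.
Now suppose we have m samples with n features each. Their feature vectors are $X_1, X_2, \dots, X_m$, their labels are $y_1, y_2, \dots, y_m$, and every label is linearly related to its feature vector. This gives the following equations:
$$\begin{aligned}
y_1 &= WX_1^T + b = w_1x_{1,1} + w_2x_{1,2} + \dots + w_nx_{1,n} + b \\
y_2 &= WX_2^T + b = w_1x_{2,1} + w_2x_{2,2} + \dots + w_nx_{2,n} + b \\
&\ \ \vdots \\
y_m &= WX_m^T + b = w_1x_{m,1} + w_2x_{m,2} + \dots + w_nx_{m,n} + b
\end{aligned}$$
These m equations form a system of equations in which b is treated as a constant. The knowns are $y_1, \dots, y_m$ and $X_1, \dots, X_m$. The unknown is W, i.e. $w_1, \dots, w_n$, so this is a system of linear equations in n unknowns. The goal of linear regression is to obtain a linear equation. Once we find the parameter W we have that equation, and with it the whole linear regression model.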
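Consider the solutions of this system of m equations in n unknowns. For the first two cases, assume the equations are independent and consistent:
When m = n, i.e. the number of samples equals the number of features, the system has exactly one solution. In other words, exactly one coefficient vector W fits the data.
When m < n, i.e. there are fewer samples than features, the system has infinitely many solutions. In other words, infinitely many coefficient vectors W fit the data.
When m > n, i.e. there are more samples than features, the system in general has no solution. In other words, no coefficient vector W satisfies every equation.
In practice m ≫ n is very common, so the system has no exact solution.
Yet we still need to use the m samples to train these n regression coefficients (the coefficient vector W). Substituting the samples leaves a system with no exact solution, because no W satisfies all of them. How, then, do we determine W?
Since no perfect W exists, we can only derive the most suitable W: a coefficient vector that makes the predictions match the sample labels as closely as possible.
The least squares method solves this problem.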
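The three cases can be seen with NumPy. The sketch below uses randomly generated data with a hypothetical "true" W, and for simplicity drops the intercept (b = 0). A square system is solved exactly with `np.linalg.solve`. An overdetermined, noisy system has no exact solution, so `np.linalg.lstsq` returns the least squares W instead and reports a non-zero sum of squared residuals.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 3
W_true = np.array([1.0, -2.0, 3.0])   # hypothetical coefficients; b is taken as 0

# m == n: a square system, solvable exactly when the matrix is non-singular
X_sq = rng.normal(size=(n, n))        # row i is the sample X_i
y_sq = X_sq @ W_true                  # y_i = W X_i^T for each row
W_sq = np.linalg.solve(X_sq, y_sq)

# m > n: more equations than unknowns; with noisy labels no W satisfies all of them,
# so lstsq returns the W that minimizes the sum of squared residuals
m = 50
X_over = rng.normal(size=(m, n))
y_over = X_over @ W_true + 0.1 * rng.normal(size=m)
W_ls, residual_ss, rank, _ = np.linalg.lstsq(X_over, y_over, rcond=None)

print(W_sq)         # recovers W_true up to rounding
print(W_ls)         # close to W_true, but not exact
print(residual_ss)  # sum of squared residuals, > 0
```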
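3. Derivation of the algorithm: least squares
Steps of the least squares method. The intercept b is omitted below; it can be absorbed into W by appending a constant feature equal to 1 to every $X_i$.
1) Suppose the best regression coefficient vector is W'. The hypothesized regression equation is then $y' = W'X^T$.
2) Substitute the feature vectors of all m known samples into the hypothesized regression equation:
$$\begin{aligned}
y_1' &= W'X_1^T = w_1x_{1,1} + w_2x_{1,2} + \dots + w_nx_{1,n} \\
y_2' &= W'X_2^T = \dots \\
&\ \ \vdots \\
y_m' &= W'X_m^T = \dots
\end{aligned}$$
The hypothesized equation now gives the predicted labels $y_1', y_2', \dots, y_m'$. All of them depend on W', and they inevitably differ from the true labels $y_1, y_2, \dots, y_m$.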
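3) Measure the difference between $y_1', \dots, y_m'$ and $y_1, \dots, y_m$:
$$H = (y_1 - y_1')^2 + (y_2 - y_2')^2 + \dots + (y_m - y_m')^2$$
Here $y_1, \dots, y_m$ are known and each $y_i'$ can be expressed through W', so H can be written as:
$$\begin{aligned}
H &= (y_1 - W'X_1^T)^2 + (y_2 - W'X_2^T)^2 + \dots + (y_m - W'X_m^T)^2 \\
&= \left(y_1^2 - 2y_1W'X_1^T + (W'X_1^T)^2\right) + \dots + \left(y_m^2 - 2y_mW'X_m^T + (W'X_m^T)^2\right) \\
&= \left[y_1^2 + \dots + y_m^2\right] - 2\left[y_1W'X_1^T + \dots + y_mW'X_m^T\right] + \left[(W'X_1^T)^2 + \dots + (W'X_m^T)^2\right]
\end{aligned}$$
H is also called the loss function.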
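4) Minimize the loss function H.
To find the minimum of H, differentiate H with respect to W'. Each $X_i$ is a $1 \times n$ row vector, so $W'X_i^T$ is a scalar and $X_i^TX_i$ is an $n \times n$ matrix:
$$H' = -2(y_1X_1 + \dots + y_mX_m) + 2(W'X_1^TX_1 + \dots + W'X_m^TX_m)$$
Let $Y = (y_1, y_2, \dots, y_m)$ and let X be the $m \times n$ matrix whose i-th row is $X_i$. Then $y_1X_1 + \dots + y_mX_m = YX$ and $X_1^TX_1 + \dots + X_m^TX_m = X^TX$, so
$$H' = -2YX + 2W'X^TX$$
Setting $H' = 0$ gives
$$W'X^TX = YX$$
The solution below follows when the matrix $X^TX$ is invertible.
$X^TX$ is invertible when $|X^TX| \neq 0$, i.e. when it has full rank. When m < n, $X^TX$ (an $n \times n$ matrix of rank at most m) cannot have full rank. When m ≥ n it has full rank provided the columns of X are linearly independent.
$$W' = YX(X^TX)^{-1}, \quad \text{equivalently} \quad W'^T = (X^TX)^{-1}X^TY^T$$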
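The derivation above can be checked numerically. The sketch below generates random data from a hypothetical W and computes the loss H. It solves the normal equation $(X^TX)W'^T = X^TY^T$ with `np.linalg.solve` rather than forming the inverse explicitly, which is the usual, numerically safer choice. The result is then compared with NumPy's own least squares solver. Because H is a convex quadratic, nudging the solution should only increase it.

```python
import numpy as np

def loss_H(W, X, Y):
    """H = sum_i (y_i - W X_i^T)^2, where row i of X is the sample X_i."""
    residual = Y - W @ X.T          # y_i - y_i' for every sample
    return residual @ residual

rng = np.random.default_rng(42)
m, n = 100, 3
X = rng.normal(size=(m, n))                     # m x n, row i is X_i
W_true = np.array([0.5, -1.5, 2.0])             # hypothetical "true" coefficients
Y = W_true @ X.T + 0.05 * rng.normal(size=m)    # Y = (y_1, ..., y_m) with a little noise

# W' X^T X = Y X  <=>  (X^T X) W'^T = X^T Y^T
W_best = np.linalg.solve(X.T @ X, X.T @ Y)

# Cross-check against NumPy's least squares solver
W_lstsq, *_ = np.linalg.lstsq(X, Y, rcond=None)

print(W_best)
print(np.allclose(W_best, W_lstsq))                        # True
print(loss_H(W_best, X, Y) < loss_H(W_best + 0.01, X, Y))  # True: moving away increases H
```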
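4. Dealing with underfitting and overfitting (model regularization). Overfitting can be penalized by introducing λ, but underfitting cannot be fixed with a λ penalty.
Overfitting occurs very easily when the sample size is very small.
Adding a regularization term to the linear regression model reduces overfitting.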
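L1 regularization:
The least squares loss function of linear regression is:
$$J(W) = (y_1 - y_1')^2 + (y_2 - y_2')^2 + \dots + (y_m - y_m')^2$$
It is denoted MSE(W). Strictly, this is the sum of squared errors; dividing it by m gives the mean squared error.
L1 regularization adds the L1 norm of the regression coefficients to this loss as a penalty term.
$$W = (w_1, w_2, \dots, w_n)$$
L1 norm: $|w_1| + |w_2| + |w_3| + \dots + |w_n|$
After adding it, the loss function becomes:
$$L(W) = J(W) + \lambda(|w_1| + |w_2| + |w_3| + \dots + |w_n|)$$
Now L(W) depends not only on the known samples but also on λ.
The regression method that uses the L1 norm as its penalty term is called Lasso regression.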
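The sketch below evaluates the penalized loss L(W) directly. It also fits scikit-learn's `Lasso` on synthetic data in which only the first two of five features matter; the data, `alpha=0.1` and the helper `lasso_loss` are illustrative assumptions. scikit-learn scales the objective differently from L(W): it minimizes $\frac{1}{2m}\|y - Xw - b\|^2 + \alpha\|w\|_1$, so its `alpha` corresponds to $\lambda = 2m\alpha$ here. A characteristic effect of the L1 penalty shows up in the output: the coefficients of the irrelevant features become exactly zero.

```python
import numpy as np
from sklearn.linear_model import Lasso

def lasso_loss(W, X, y, lam):
    """Hypothetical helper: L(W) = J(W) + lam * (|w_1| + ... + |w_n|), J = sum of squared errors."""
    residual = y - X @ W
    return residual @ residual + lam * np.abs(W).sum()

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 5))
# Synthetic data: only the first two features actually matter
y = 3 * X[:, 0] - 2 * X[:, 1] + 0.1 * rng.normal(size=100)

# scikit-learn's alpha maps to lam = 2 * m * alpha in L(W) above
alpha = 0.1
lasso = Lasso(alpha=alpha).fit(X, y)
lam = 2 * len(y) * alpha

print(lasso.coef_)  # the last three coefficients are exactly 0.0
y_centered = y - lasso.intercept_   # remove the fitted intercept b
print(lasso_loss(lasso.coef_, X, y_centered, lam))
print(lasso_loss(np.zeros(5), X, y_centered, lam))  # much larger
```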
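L2 regularization:
$$J(W) = (y_1 - y_1')^2 + (y_2 - y_2')^2 + \dots + (y_m - y_m')^2$$
Add the squared L2 norm of the coefficient vector, $w_1^2 + w_2^2 + \dots + w_n^2$:
$$L(W) = J(W) + \lambda(w_1^2 + w_2^2 + \dots + w_n^2)$$
Set $L'(W) = 0$, where X is the $m \times n$ matrix whose i-th row is the sample $X_i$ and $Y = (y_1, \dots, y_m)$. This gives
$$W^T = (X^TX + \lambda I)^{-1}X^TY^T$$
where I is the identity matrix. $X^TX + \lambda I$ is invertible for every λ > 0.
The method that adds this L2 penalty is called ridge regression (Ridge regression).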
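The closed-form ridge solution can be compared with scikit-learn's `Ridge`. Without an intercept, `Ridge` minimizes $\|y - Xw\|^2 + \alpha\|w\|^2$, which is the same objective as L(W) above with α = λ. The random data and λ = 5 below are illustrative assumptions. The check also shows the shrinkage: the ridge coefficients have a smaller norm than the plain least squares ones.

```python
import numpy as np
from sklearn.linear_model import Ridge

rng = np.random.default_rng(1)
X = rng.normal(size=(30, 4))                                # row i is the sample X_i
y = X @ np.array([1.0, 0.0, -2.0, 0.5]) + 0.1 * rng.normal(size=30)
lam = 5.0

# Closed form: W^T = (X^T X + lam * I)^{-1} X^T Y^T, solved without an explicit inverse
W_closed = np.linalg.solve(X.T @ X + lam * np.eye(X.shape[1]), X.T @ y)

# scikit-learn with the same objective (no intercept, alpha = lam)
ridge = Ridge(alpha=lam, fit_intercept=False).fit(X, y)

# Ordinary least squares for comparison
W_ols, *_ = np.linalg.lstsq(X, y, rcond=None)

print(np.allclose(W_closed, ridge.coef_))                   # True
print(np.linalg.norm(W_closed), np.linalg.norm(W_ols))      # ridge norm is smaller
```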
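I. Ordinary linear regression
1. Principle
In classification the target variable is nominal, whereas regression predicts continuous values.
How do we obtain a regression equation from a pile of data?
Suppose the input data are stored in a matrix X and the regression coefficients in a vector W. For given data X1, the prediction is then given by
$$Y = XW$$
The question is: given some X and the corresponding Y, how do we find W?
A common approach is to find the W that minimizes the error, where the error is the difference between the predicted and the true Y values. Simply adding up these differences would let positive and negative errors cancel each other out, so we use the squared error instead.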
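Least squares
The squared error can be written as (with $x_i^T$ the i-th row of X):
$$\sum_{i=1}^{m}(y_i - x_i^TW)^2 = (Y - XW)^T(Y - XW)$$
Differentiate with respect to W. The squared error is smallest where the derivative is zero, and there W equals:
$$\hat{W} = (X^TX)^{-1}X^TY$$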
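For example, given a set of sample points (figure not available), solving for the regression line yields the best-fit line through them (figure not available).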
2. Example
from sklearn.linear_model import LinearRegression
from sklearn import datasets
diabetes = datasets.load_diabetes()
diabetes
{'DESCR': '.. _diabetes_dataset:\n\nDiabetes dataset\n----------------\n\nTen baseline variables, age, sex, body mass index, average blood\npressure, and six blood serum measurements were obtained for each of n =\n442 diabetes patients, as well as the response of interest, a\nquantitative measure of disease progression one year after baseline.\n\n**Data Set Characteristics:**\n\n :Number of Instances: 442\n\n :Number of Attributes: First 10 columns are numeric predictive values\n\n :Target: Column 11 is a quantitative measure of disease progression one year after baseline\n\n :Attribute Information:\n - Age\n - Sex\n - Body mass index\n - Average blood pressure\n - S1\n - S2\n - S3\n - S4\n - S5\n - S6\n\nNote: Each of these 10 feature variables have been mean centered and scaled by the standard deviation times `n_samples` (i.e. the sum of squares of each column totals 1).\n\nSource URL:\nhttp://www4.stat.ncsu.edu/~boos/var.select/diabetes.html\n\nFor more information see:\nBradley Efron, Trevor Hastie, Iain Johnstone and Robert Tibshirani (2004) "Least Angle Regression," Annals of Statistics (with discussion), 407-499.\n(http://web.stanford.edu/~hastie/Papers/LARS/LeastAngle_2002.pdf)',
'data': array([[ 0.03807591, 0.05068012, 0.06169621, ..., -0.00259226,
0.01990842, -0.01764613],
[-0.00188202, -0.04464164, -0.05147406, ..., -0.03949338,
-0.06832974, -0.09220405],
[ 0.08529891, 0.05068012, 0.04445121, ..., -0.00259226,
0.00286377, -0.02593034],
...,
[ 0.04170844, 0.05068012, -0.01590626, ..., -0.01107952,
-0.04687948, 0.01549073],
[-0.04547248, -0.04464164, 0.03906215, ..., 0.02655962,
0.04452837, -0.02593034],
[-0.04547248, -0.04464164, -0.0730303 , ..., -0.03949338,
-0.00421986, 0.00306441]]),
'data_filename': 'C:\\Anaconda3\\lib\\site-packages\\sklearn\\datasets\\data\\diabetes_data.csv.gz',
'feature_names': ['age',
'sex',
'bmi',
'bp',
's1',
's2',
's3',
's4',
's5',
's6'],
'target': array([151., 75., 141., 206., 135., 97., 138., 63., 110., 310., 101.,
69., 179., 185., 118., 171., 166., 144., 97., 168., 68., 49.,
68., 245., 184., 202., 137., 85., 131., 283., 129., 59., 341.,
87., 65., 102., 265., 276., 252., 90., 100., 55., 61., 92.,
259., 53., 190., 142., 75., 142., 155., 225., 59., 104., 182.,
128., 52., 37., 170., 170., 61., 144., 52., 128., 71., 163.,
150., 97., 160., 178., 48., 270., 202., 111., 85., 42., 170.,
200., 252., 113., 143., 51., 52., 210., 65., 141., 55., 134.,
42., 111., 98., 164., 48., 96., 90., 162., 150., 279., 92.,
83., 128., 102., 302., 198., 95., 53., 134., 144., 232., 81.,
104., 59., 246., 297., 258., 229., 275., 281., 179., 200., 200.,
173., 180., 84., 121., 161., 99., 109., 115., 268., 274., 158.,
107., 83., 103., 272., 85., 280., 336., 281., 118., 317., 235.,
60., 174., 259., 178., 128., 96., 126., 288., 88., 292., 71.,
197., 186., 25., 84., 96., 195., 53., 217., 172., 131., 214.,
59., 70., 220., 268., 152., 47., 74., 295., 101., 151., 127.,
237., 225., 81., 151., 107., 64., 138., 185., 265., 101., 137.,
143., 141., 79., 292., 178., 91., 116., 86., 122., 72., 129.,
142., 90., 158., 39., 196., 222., 277., 99., 196., 202., 155.,
77., 191., 70., 73., 49., 65., 263., 248., 296., 214., 185.,
78., 93., 252., 150., 77., 208., 77., 108., 160., 53., 220.,
154., 259., 90., 246., 124., 67., 72., 257., 262., 275., 177.,
71., 47., 187., 125., 78., 51., 258., 215., 303., 243., 91.,
150., 310., 153., 346., 63., 89., 50., 39., 103., 308., 116.,
145., 74., 45., 115., 264., 87., 202., 127., 182., 241., 66.,
94., 283., 64., 102., 200., 265., 94., 230., 181., 156., 233.,
60., 219., 80., 68., 332., 248., 84., 200., 55., 85., 89.,
31., 129., 83., 275., 65., 198., 236., 253., 124., 44., 172.,
114., 142., 109., 180., 144., 163., 147., 97., 220., 190., 109.,
191., 122., 230., 242., 248., 249., 192., 131., 237., 78., 135.,
244., 199., 270., 164., 72., 96., 306., 91., 214., 95., 216.,
263., 178., 113., 200., 139., 139., 88., 148., 88., 243., 71.,
77., 109., 272., 60., 54., 221., 90., 311., 281., 182., 321.,
58., 262., 206., 233., 242., 123., 167., 63., 197., 71., 168.,
140., 217., 121., 235., 245., 40., 52., 104., 132., 88., 69.,
219., 72., 201., 110., 51., 277., 63., 118., 69., 273., 258.,
43., 198., 242., 232., 175., 93., 168., 275., 293., 281., 72.,
140., 189., 181., 209., 136., 261., 113., 131., 174., 257., 55.,
84., 42., 146., 212., 233., 91., 111., 152., 120., 67., 310.,
94., 183., 66., 173., 72., 49., 64., 48., 178., 104., 132.,
220., 57.]),
'target_filename': 'C:\\Anaconda3\\lib\\site-packages\\sklearn\\datasets\\data\\diabetes_target.csv.gz'}
data = diabetes.data
target = diabetes.target
data.shape
(442, 10)
from sklearn.model_selection import train_test_split
x_train,x_test,y_train,y_test = train_test_split(data,target,test_size=0.2)
# Create a linear regression model
lr = LinearRegression()
lr
LinearRegression(copy_X=True, fit_intercept=True, n_jobs=None,
normalize=False)
# Train the model
lr.fit(x_train,y_train)
LinearRegression(copy_X=True, fit_intercept=True, n_jobs=None,
normalize=False)
x_train.shape
(353, 10)
x_train[0]
array([-0.08906294, -0.04464164, -0.01159501, -0.03665645, 0.01219057,
0.02499059, -0.03603757, 0.03430886, 0.02269202, -0.00936191])
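Training process:
1) Assume a regression coefficient vector W = (w1, w2, …, w10), giving the regression equation y = W*X^T. The intercept b is fitted as well, because fit_intercept=True.
2) Substitute the sample features of x_train into the regression equation:
y1' = w1*x1_1 + w2*x1_2 + … + w10*x1_10
y2' = w1*x2_1 + w2*x2_2 + … + w10*x2_10
…
y353' = w1*x353_1 + w2*x353_2 + … + w10*x353_10
3) Compute the difference between y1', y2', …, y353' and y1, y2, …, y353:
H = (y1' - y1)^2 + (y2' - y2)^2 + … + (y353' - y353)^2
  = (y1 - (w1*x1_1 + w2*x1_2 + … + w10*x1_10))^2 + … + (y353 - (w1*x353_1 + w2*x353_2 + … + w10*x353_10))^2
4) Set the derivative of H to zero to obtain the most suitable W.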
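To see that `LinearRegression` produces the least squares solution described in these steps, one can append a column of ones to x_train, which turns the intercept into an eleventh coefficient, and solve the normal equation directly. This is a sketch: `random_state=0` is an assumption added so that the split is reproducible, and scikit-learn internally uses a least squares solver rather than literally differentiating H. Both routes should agree to floating-point precision.

```python
import numpy as np
from sklearn import datasets
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split

diabetes = datasets.load_diabetes()
x_train, x_test, y_train, y_test = train_test_split(
    diabetes.data, diabetes.target, test_size=0.2, random_state=0)  # random_state assumed

lr = LinearRegression().fit(x_train, y_train)

# Append a column of ones so the intercept b becomes the last coefficient,
# then solve the normal equation (A^T A) w = A^T y
A = np.hstack([x_train, np.ones((x_train.shape[0], 1))])
w = np.linalg.solve(A.T @ A, A.T @ y_train)

print(np.allclose(w[:-1], lr.coef_), np.isclose(w[-1], lr.intercept_))  # True True
```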
# After the training above we have the linear regression model lr, which is essentially a linear equation
y_ = lr.predict(x_test)
y_,y_test
(array([154.89352865, 117.66145264, 115.39462132, 263.02148887,
146.79323278, 205.27899278, 247.24868683, 76.58650704,
103.58635735, 169.200138 , 149.83791791, 88.43928491,
155.3842227 , 151.68579974, 214.49443703, 126.54484392,
61.81803479, 193.60691532, 170.24744126, 77.51329315,
134.56098579, 116.34830106, 84.36518853, 234.46506385,
137.48484809, 167.10021854, 165.62819288, 109.71417742,
266.77751244, 61.59933921, 165.79555251, 162.36833465,
169.75759087, 104.94767529, 160.20086893, 96.36430463,
136.63919961, 120.21130918, 185.75507159, 49.03023118,
117.14303835, 257.71458883, 82.94732875, 98.37461205,
52.68519927, 137.12229317, 162.51850782, 106.44530913,
119.06175181, 174.46433491, 179.28787779, 91.51218134,
259.66893177, 223.02673926, 165.98102707, 233.3523202 ,
159.51666131, 145.11491916, 155.7050318 , 170.46016154,
130.9758386 , 194.49753158, 167.88700093, 171.45787869,
118.96031811, 151.68601288, 75.22126957, 196.85985685,
290.20423187, 129.89394565, 140.30786115, 235.54147671,
162.24369658, 233.33357054, 124.80124842, 179.57583407,
196.42418401, 145.21897625, 77.74533059, 229.74545767,
186.99733822, 231.66435237, 76.34429234, 52.15053722,
256.88257267, 101.1362802 , 230.63949963, 188.43220532,
231.67447417]),
array([100., 71., 88., 310., 182., 288., 341., 65., 88., 121., 219.,
64., 131., 202., 259., 144., 43., 292., 277., 104., 214., 253.,
69., 236., 170., 151., 77., 111., 303., 99., 206., 120., 91.,
199., 127., 47., 59., 178., 129., 78., 68., 275., 60., 81.,
90., 50., 109., 89., 66., 258., 217., 84., 310., 225., 151.,
99., 94., 103., 141., 244., 148., 191., 122., 311., 145., 95.,
158., 137., 270., 42., 90., 152., 129., 280., 68., 126., 281.,
142., 55., 208., 90., 128., 134., 72., 132., 84., 275., 67.,
321.]))
Evaluating performance
Empirical (training) error
y_t = lr.predict(x_train)
from sklearn.metrics import mean_absolute_error,mean_squared_error
print("Training mean absolute error:",mean_absolute_error(y_pred=y_t,y_true=y_train))
print("Training mean squared error:",mean_squared_error(y_pred=y_t,y_true=y_train))
Training mean absolute error: 41.49930369562525
Training mean squared error: 2670.0577236867643
Generalization (test) error
print("Test mean absolute error:",mean_absolute_error(y_pred=y_,y_true=y_test))
print("Test mean squared error:",mean_squared_error(y_pred=y_,y_true=y_test))
Test mean absolute error: 50.5681075027226
Test mean squared error: 3690.850980625107
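Both metrics are simple averages: the mean absolute error averages $|y_i - \hat{y}_i|$ and the mean squared error averages $(y_i - \hat{y}_i)^2$. The sketch below computes them by hand on a few made-up values and compares the results with scikit-learn. As expected, the test errors above are larger than the training errors, because the model was fitted to the training data.

```python
import numpy as np
from sklearn.metrics import mean_absolute_error, mean_squared_error

# Made-up true and predicted values, for illustration only
y_true = np.array([3.0, -0.5, 2.0, 7.0])
y_pred = np.array([2.5, 0.0, 2.0, 8.0])

mae = np.mean(np.abs(y_true - y_pred))  # (0.5 + 0.5 + 0 + 1) / 4 = 0.5
mse = np.mean((y_true - y_pred) ** 2)   # (0.25 + 0.25 + 0 + 1) / 4 = 0.375

print(mae, mean_absolute_error(y_true, y_pred))
print(mse, mean_squared_error(y_true, y_pred))
```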
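About overfitting and underfitting
Fitting refers to the process in machine learning of updating the parameters so that the model matches the training set more and more closely.
# Construct a small training set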
x_train = [[6],[8],[10],[14],[18]]
y_train = [[7],[9],[13],[17.5],[18]]
Fit the model with a degree-1 polynomial
# Create a linear regression model
lgr = LinearRegression()
lgr.fit(x_train,y_train)
LinearRegression(copy_X=True, fit_intercept=True, n_jobs=None,
normalize=False)
import numpy as np
xx = np.linspace(0,26,100).reshape((100,1))
# 预测xx对应的标签
yy = lgr.predict(xx)
import matplotlib.pyplot as plt
%matplotlib inline
plt.axis([0,25,0,25])
plt.scatter(x_train,y_train)
plt.plot(xx,yy,c="r")
print("Linear regression R^2 score:",lgr.score(x_train,y_train))
Linear regression R^2 score: 0.9100015964240102
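For regressors, `score` does not return a classification accuracy. It returns the coefficient of determination $R^2 = 1 - SS_{res}/SS_{tot}$. The sketch below recomputes it by hand on the same five training points (using 1-D labels for simplicity) and matches the value printed above.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

x_train = np.array([[6], [8], [10], [14], [18]])
y_train = np.array([7, 9, 13, 17.5, 18])

lgr = LinearRegression().fit(x_train, y_train)
y_pred = lgr.predict(x_train)

# R^2 = 1 - (residual sum of squares) / (total sum of squares around the mean)
ss_res = np.sum((y_train - y_pred) ** 2)
ss_tot = np.sum((y_train - y_train.mean()) ** 2)
r2 = 1 - ss_res / ss_tot

print(r2, lgr.score(x_train, y_train))  # both about 0.9100
```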
Fit the model with a degree-2 polynomial
from sklearn.preprocessing import PolynomialFeatures
# Use PolynomialFeatures to map x to new degree-2 polynomial features
poly2 = PolynomialFeatures(degree=2)
x_train_poly2 = poly2.fit_transform(x_train)
lgr_poly2 = LinearRegression()
lgr_poly2.fit(x_train_poly2,y_train)
LinearRegression(copy_X=True, fit_intercept=True, n_jobs=None,
normalize=False)
yy_poly2 = lgr_poly2.predict(poly2.transform(xx))
plt.axis([0,25,0,25])
plt.scatter(x_train,y_train)
plt.plot(xx,yy,c="r")
plt.plot(xx,yy_poly2,c='g')
Fit the model with a degree-4 polynomial
poly4 = PolynomialFeatures(degree=4)
x_train_poly4 = poly4.fit_transform(x_train)
lgr_poly4 = LinearRegression()
lgr_poly4.fit(x_train_poly4,y_train)
LinearRegression(copy_X=True, fit_intercept=True, n_jobs=None,
normalize=False)
yy_poly4 = lgr_poly4.predict(poly4.transform(xx))
plt.axis([0,25,0,25])
plt.scatter(x_train,y_train)
plt.plot(xx,yy,c="r",label="degree=1")
plt.plot(xx,yy_poly2,c="g",label="degree=2")
plt.plot(xx,yy_poly4,c="b",label="degree=4")
plt.legend()
print("Degree-1 polynomial R^2:",lgr.score(x_train,y_train))
print("Degree-2 polynomial R^2:",lgr_poly2.score(x_train_poly2,y_train))
print("Degree-4 polynomial R^2:",lgr_poly4.score(x_train_poly4,y_train))
Degree-1 polynomial R^2: 0.9100015964240102
Degree-2 polynomial R^2: 0.9816421639597427
Degree-4 polynomial R^2: 1.0
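When the model complexity is very low (degree=1), the model neither fits the training set well nor performs well on test data. This is called underfitting.
When we blindly pursue high model complexity (degree=4), the model fits almost every training point, but its curve fluctuates wildly and it nearly loses its ability to predict unseen data, so its generalization is still poor. This is called overfitting.
Both overfitting and underfitting reduce a model's ability to generalize.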
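II. Ridge regression
1. Principle
What if the data have more features than samples? Can we still use linear regression and the method above to make predictions?
No, the method described earlier no longer works. With fewer samples than features, $X^TX$ is not of full rank (X has rank at most m < n), and a rank-deficient matrix cannot be inverted.
To solve this problem, statisticians introduced ridge regression.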
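Shrinkage methods can suppress unimportant parameters, and Lasso can even set them exactly to zero, which helps us understand the data better. Compared with simple linear regression, shrinkage methods can also achieve better prediction performance.
[Note] In ridge regression, model performance depends not only on the data and the algorithm but also on the shrinkage term λI.
Ridge regression is least squares with an added L2 regularization term (the λI term). It is mainly useful when overfitting is severe or when the variables are multicollinear. Ridge regression is biased: the bias is accepted in exchange for a smaller variance.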
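Summary
1. Ridge regression can handle problems with more features than samples.
2. As a shrinkage method, ridge regression can indicate which features are more or less important (larger or smaller coefficients), somewhat like dimensionality reduction. Unlike Lasso, it does not set coefficients exactly to zero.
3. A shrinkage method can be viewed as adding bias to a model in exchange for lower variance.
Ridge regression is used for two kinds of problems:
1. Fewer data points than variables
2. Collinearity among the variables, where the least squares coefficients are unstable and have high variance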
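The first case can be illustrated with NumPy. With more features than samples, $X^TX$ is singular, but adding λI makes the system solvable. The random data and λ = 0.1 below are illustrative assumptions. The last check confirms that the solution satisfies the ridge optimality condition $X^T(XW - y) + \lambda W = 0$.

```python
import numpy as np

rng = np.random.default_rng(0)
m, n = 5, 8                        # fewer samples than features
X = rng.normal(size=(m, n))        # row i is the sample X_i
y = rng.normal(size=m)

XtX = X.T @ X                      # n x n, but its rank is at most m
print(np.linalg.matrix_rank(XtX))  # at most 5, i.e. less than n = 8: X^T X is singular

lam = 0.1
W_ridge = np.linalg.solve(XtX + lam * np.eye(n), X.T @ y)   # solvable for any lam > 0
print(W_ridge.shape)               # (8,)

# Optimality condition of the ridge loss: X^T (X W - y) + lam * W = 0
print(np.allclose(X.T @ (X @ W_ridge - y) + lam * W_ridge, 0))  # True
```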
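2. Example
Ridge regression uses L2 regularization: it adds the penalty $\lambda\|W\|^2$ to the loss. To minimize the objective, this regularization drives most elements of the parameter vector toward small values, which suppresses the differences between the parameters.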
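The shrinking effect is easy to observe by fitting `Ridge` on the diabetes data with increasing penalty values and recording the norm of the coefficient vector. The grid of alphas below is an arbitrary illustrative choice. Mathematically, the norm of the ridge solution decreases as alpha grows.

```python
import numpy as np
from sklearn import datasets
from sklearn.linear_model import Ridge

X, y = datasets.load_diabetes(return_X_y=True)

alphas = [0.01, 0.1, 1.0, 10.0, 100.0]   # arbitrary grid of penalty values
norms = []
for alpha in alphas:
    ridge = Ridge(alpha=alpha).fit(X, y)
    norms.append(np.linalg.norm(ridge.coef_))   # ||W||_2 of the fitted coefficients

for alpha, norm in zip(alphas, norms):
    print(f"alpha={alpha:>6}: ||W|| = {norm:.2f}")
```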
Let's look at how ridge regression performs on the poly4 features above.
from sklearn.linear_model import Ridge
rigde_poly4 = Ridge()
rigde_poly4.fit(x_train_poly4,y_train)
Ridge(alpha=1.0, copy_X=True, fit_intercept=True, max_iter=None,
normalize=False, random_state=None, solver='auto', tol=0.001)
rigde_poly4.score(x_train_poly4,y_train)
0.9941615674726115
Inspect the coefficients
rigde_poly4.coef_ # regression coefficients
array([[ 0. , -0.00492536, 0.12439632, -0.00046471, -0.00021205]])
lgr_poly4.coef_
array([[ 0.00000000e+00, -2.51739583e+01, 3.68906250e+00,
-2.12760417e-01, 4.29687500e-03]])
Tuning the ridge regression parameter
x_train,x_test,y_train,y_test = train_test_split(data,target,test_size=0.2)
rigde = Ridge(alpha=100) # the alpha parameter is the penalty coefficient λ
rigde
Ridge(alpha=100, copy_X=True, fit_intercept=True, max_iter=None,
normalize=False, random_state=None, solver='auto', tol=0.001)
rigde.fit(x_train,y_train)
Ridge(alpha=100, copy_X=True, fit_intercept=True, max_iter=None,
normalize=False, random_state=None, solver='auto', tol=0.001)
y_ = rigde.predict(x_test)
y_t = rigde.predict(x_train)
print("Training error (MSE):",mean_squared_error(y_pred=y_t,y_true=y_train))
Training error (MSE): 5778.490996069103
print("Test error (MSE):",mean_squared_error(y_pred=y_,y_true=y_test))
Test error (MSE): 5845.025458771033
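With alpha=100 the training MSE (about 5778) is far above the unregularized model's training MSE (about 2670), which suggests that this penalty is too strong for these features. Rather than guessing alpha, one option is to let scikit-learn's `RidgeCV` choose it by cross-validation from a candidate grid. In the sketch below, the grid and `random_state=0` are arbitrary assumptions. By default `RidgeCV` uses an efficient leave-one-out scheme.

```python
import numpy as np
from sklearn import datasets
from sklearn.linear_model import RidgeCV
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split

X, y = datasets.load_diabetes(return_X_y=True)
x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

alphas = np.logspace(-3, 3, 13)       # candidate penalty values (an arbitrary grid)
ridge_cv = RidgeCV(alphas=alphas).fit(x_train, y_train)

test_mse = mean_squared_error(y_test, ridge_cv.predict(x_test))
print("chosen alpha:", ridge_cv.alpha_)
print("test MSE:", test_mse)
```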
[Note] If the model is underfitting, ridge regression and Lasso regression cannot be used to improve it.
Plotting the ridge trace
x_train = 1/(np.arange(1,11) + np.arange(0,10).reshape((10,1)))
x_train
array([[1. , 0.5 , 0.33333333, 0.25 , 0.2 ,
0.16666667, 0.14285714, 0.125 , 0.11111111, 0.1 ],
[0.5 , 0.33333333, 0.25 , 0.2 , 0.16666667,
0.14285714, 0.125 , 0.11111111, 0.1 , 0.09090909],
[0.33333333, 0.25 , 0.2 , 0.16666667, 0.14285714,
0.125 , 0.11111111, 0.1 , 0.09090909, 0.08333333],
[0.25 , 0.2 , 0.16666667, 0.14285714, 0.125 ,
0.11111111, 0.1 , 0.09090909, 0.08333333, 0.07692308],
[0.2 , 0.16666667, 0.14285714, 0.125 , 0.11111111,
0.1 , 0.09090909, 0.08333333, 0.07692308, 0.07142857],
[0.16666667, 0.14285714, 0.125 , 0.11111111, 0.1 ,
0.09090909, 0.08333333, 0.07692308, 0.07142857, 0.06666667],
[0.14285714, 0.125 , 0.11111111, 0.1 , 0.09090909,
0.08333333, 0.07692308, 0.07142857, 0.06666667, 0.0625 ],
[0.125 , 0.11111111, 0.1 , 0.09090909, 0.08333333,
0.07692308, 0.07142857, 0.06666667, 0.0625 , 0.05882353],
[0.11111111, 0.1 , 0.09090909, 0.08333333, 0.07692308,
0.07142857, 0.06666667, 0.0625 , 0.05882353, 0.05555556],
[0.1 , 0.09090909, 0.08333333, 0.07692308, 0.07142857,
0.06666667, 0.0625 , 0.05882353, 0.05555556, 0.05263158]])
y_train = np.ones(10)
# Create a series of shrinkage coefficients (alphas)
alphas = np.logspace(-10,-2,200)
# Fit one Ridge model for each shrinkage coefficient above,
# collecting the regression coefficients of every fit in a list
w_list = []
rigde = Ridge(fit_intercept=False)
for alpha in alphas:
    rigde.set_params(alpha=alpha)
    rigde.fit(x_train,y_train)
    w_list.append(rigde.coef_)
w_list
[array([ 2.64506194, -27.60370846, 7.99288361, 133.67548529,
18.04324382, -123.85507 , -175.62005733, -113.78633307,
45.15379234, 274.0230386 ]),
array([ 2.77495268, -30.12097789, 17.90676387, 125.4256614 ,
10.71785186, -122.86634897, -168.64642033, -106.59631559,
46.70073763, 265.22299813]),
array([ 2.89641313, -32.48098438, 27.21874539, 117.65913701,
3.83376846, -121.92540992, -162.07888102, -99.83164383,
48.15084054, 256.93546905]),
array([ 3.009596 , -34.68697138, 35.94234471, 110.36416994,
-2.61959408, -121.03070286, -155.90692527, -93.48139787,
49.50639034, 249.14682045]),
array([ 3.11468953, -36.74289413, 44.09395415, 103.52651692,
-8.65489102, -120.18031137, -150.11799288, -87.53261659,
50.77013506, 241.84083146]),
array([ 3.21190947, -38.65326749, 51.69225511, 97.12991693,
-14.28643421, -119.37200504, -144.69785745, -81.97067415,
51.94511301, 234.99922495]),
array([ 3.30149189, -40.42303219, 58.75770142, 91.15652392,
-19.52984583, -118.60327551, -139.63095685, -76.77966389,
53.03460166, 228.60208673]),
array([ 3.38368582, -42.05741686, 65.31198943, 85.58733247,
-24.4016585 , -117.87139359, -134.90077452, -71.94272906,
54.0420226 , 222.62831057]),
array([ 3.45874753, -43.56183117, 71.37764817, 80.4025176 ,
-28.91903564, -117.17344965, -130.49008719, -67.44236211,
54.97089055, 217.05593107]),
array([ 3.52693448, -44.94175495, 76.97762106, 75.58175003,
-33.09943024, -116.50638646, -126.38126767, -63.26070859,
55.82476096, 211.86247804]),
array([ 3.58850053, -46.20264871, 82.1349192 , 71.10450313,
-36.96038318, -115.86703676, -122.55649672, -59.37975927,
56.60716168, 207.02524377]),
array([ 3.64369182, -47.34987902, 86.8723451 , 66.95025578,
-40.51930745, -115.25212225, -118.99794723, -55.78158192,
57.32155585, 202.52152579]),
array([ 3.69274279, -48.38864607, 91.21221826, 63.09871531,
-43.79328968, -114.65830059, -115.68796058, -52.44847261,
57.97129714, 198.32883386]),
array([ 3.73587319, -49.32393111, 95.1761821 , 59.52996384,
-46.79895092, -114.082172 , -112.60916511, -49.36310183,
58.55960985, 194.42504105]),
array([ 3.77328535, -50.1604467 , 98.78501929, 56.22460694,
-49.55232604, -113.52028761, -109.74460055, -46.50862017,
59.0895629 , 190.78851845]),
array([ 3.80516223, -50.9026046 , 102.05853672, 53.16386363,
-52.06879811, -112.96915014, -107.07776738, -43.86873427,
59.56403236, 187.39822877]),
array([ 3.83166554, -51.55448365, 105.01545591, 50.3296263 ,
-54.36297942, -112.42522698, -104.59271476, -41.4277858 ,
59.98571015, 184.23379904]),
array([ 3.85293445, -52.11980838, 107.67333465, 47.70453898,
-56.44870113, -111.88495469, -102.27405606, -39.17077523,
60.357071 , 181.27556724]),
array([ 3.86908473, -52.60193681, 110.04853443, 45.27199557,
-58.33896185, -111.34472478, -100.10700194, -37.0834019 ,
60.68038155, 178.50460369]),
array([ 3.88020768, -53.0038429 , 112.15615605, 43.01620765,
-60.04592392, -110.8008918 , -98.07737885, -35.1520673 ,
60.9576847 , 175.90273293]),
array([ 3.88637015, -53.32812024, 114.01006926, 40.92214344,
-61.5808884 , -110.24976275, -96.17160943, -33.36388461,
61.19079797, 173.45252422]),
array([ 3.88761396, -53.57697284, 115.622877 , 38.97557705,
-62.95431328, -109.68760382, -94.37672115, -31.70666511,
61.38132252, 171.13727992]),
array([ 3.88395613, -53.75222205, 117.00595084, 37.16304093,
-64.17581927, -109.11062937, -92.68032262, -30.16890405,
61.53063971, 168.94101426]),
array([ 3.87538924, -53.85531327, 118.16945884, 35.47180265,
-65.25421761, -108.51498985, -91.07058274, -28.73976058,
61.63991033, 166.84842981]),
array([ 3.86188178, -53.88732266, 119.12238878, 33.88985255,
-66.19752941, -107.89678662, -89.53620754, -27.40903374,
61.7100896 , 164.84488487]),
array([ 3.84337887, -53.84896862, 119.87258897, 32.40587171,
-67.013012 , -107.25206421, -88.06642446, -26.167136 ,
61.74193059, 162.916368 ]),
array([ 3.81980348, -53.74063216, 120.42683714, 31.00918391,
-67.70720436, -106.57680708, -86.65094475, -25.00506088,
61.73599183, 161.04946072]),
array([ 3.79105752, -53.56237115, 120.79087986, 29.68974112,
-68.28595502, -105.86695287, -85.27994892, -23.91435811,
61.69264982, 159.23131544]),
array([ 3.75702352, -53.31394529, 120.96950675, 28.4380848 ,
-68.75447239, -105.11839297, -83.94406279, -22.88710251,
61.61211143, 157.44962701]),
array([ 3.71756654, -52.99484069, 120.96661617, 27.245319 ,
-69.11735761, -104.32699369, -82.63434255, -21.91586744,
61.49442847, 155.6926156 ]),
array([ 3.67253648, -52.60430183, 120.78530056, 26.10307889,
-69.37866046, -103.48860616, -81.34225707, -20.99369753,
61.33950993, 153.94901252]),
array([ 3.62177083, -52.14136486, 120.42792595, 25.00351425,
-69.5419232 , -102.59909559, -80.05968493, -20.11408734,
61.14714792, 152.20805246]),
array([ 3.56509784, -51.60489814, 119.89623134, 23.93926266,
-69.6102345 , -101.65436586, -78.77890784, -19.27095884,
60.91702811, 150.45947987]),
array([ 3.50234022, -50.99364742, 119.19143216, 22.90343432,
-69.58628291, -100.65039876, -77.49261708, -18.45864347,
60.64876218, 148.69355412]),
array([ 3.43331926, -50.30628524, 118.31432875, 21.88960498,
-69.47241815, -99.58329305, -76.19392307, -17.67186918,
60.34190527, 146.90107659]),
array([ 3.35785961, -49.54146725, 117.26542999, 20.89180372,
-69.2707048 , -98.4493185 , -74.87637583, -16.90574754,
59.99598908, 145.07341769]),
array([ 3.27579443, -48.69789333, 116.04508098, 19.9045173 ,
-68.98299281, -97.24497144, -73.53398708, -16.15576756,
59.61055048, 143.20255929]),
array([ 3.18697118, -47.77437477, 114.65360463, 18.92268504,
-68.61097672, -95.96703818, -72.16126505, -15.4177916 ,
59.18516452, 141.28114589]),
array([ 3.09125779, -46.76990544, 113.09144555, 17.94171247,
-68.15627104, -94.61266516, -70.75325056, -14.68805326,
58.71947844, 139.30254415]),
array([ 2.9885492 , -45.68373707, 111.35932194, 16.95747966,
-67.62047441, -93.1794361 , -69.30556317, -13.96316182,
58.21325083, 137.26091109]),
array([ 2.87877422, -44.51545747, 109.45838188, 15.96635617,
-67.00524252, -91.66545094, -67.81444969, -13.24010458,
57.66638811, 135.15126785]),
array([ 2.76190244, -43.2650699 , 107.39036039, 14.9652177 ,
-66.31235851, -90.06940744, -66.27683471, -12.51625554,
57.07898383, 132.96957802]),
array([ 2.63795117, -41.93307133, 105.15773344, 13.95146454,
-65.5438021 , -88.39068246, -64.69037373, -11.78938193,
56.45135668, 130.71282607]),
array([ 2.50699203, -40.52052767, 102.76386559, 12.92303929,
-64.70181449, -86.62941058, -63.05350362, -11.05765289,
55.7840866 , 128.3790945 ]),
array([ 2.36915707, -39.02914331, 100.21314689, 11.87844275,
-63.78895935, -84.78655568, -61.36548973, -10.31964645,
55.07804892, 125.96763373]),
array([ 2.224644 , -37.46132053, 97.51110882, 10.81675016,
-62.80817578, -82.86397382, -59.6264656 , -9.5743555 ,
54.33444361, 123.47892206]),
array([ 2.07372058, -35.82020928, 94.66452565, 9.73761498,
-61.76282256, -80.86446122, -57.83746169, -8.8211898 ,
53.55481799, 120.91471082]),
array([ 1.91672733, -34.10973933, 91.68147891, 8.64127502,
-60.65671056, -78.79178716, -56.00042247, -8.05997466,
52.74108418, 118.27804959]),
array([ 1.75407894, -32.33463637, 88.57139534, 7.52854386,
-59.49412102, -76.65070532, -54.11820729, -7.29094417,
51.89552612, 115.57328936]),
array([ 1.58626365, -30.50041707, 85.34504457, 6.40079611,
-58.27981062, -74.44694236, -52.19457148, -6.51472839,
51.02079724, 112.80605821]),
array([ 1.41384083, -28.61336284, 82.01449868, 5.2599403 ,
-57.01899877, -72.1871618 , -50.23412933, -5.73233357,
50.11990861, 109.98320775]),
array([ 1.23743625, -26.68046964, 78.59304663, 4.1083831 ,
-55.71733825, -69.87890136, -48.2422952 , -4.9451158 ,
49.19620478, 107.1127297 ]),
array([ 1.05773558, -24.70937637, 75.09507103, 2.94897949,
-54.38086942, -67.53048436, -46.22520387, -4.15474698,
48.25332974, 104.2036414 ]),
array([ 0.87547575, -22.70827065, 71.53588228, 1.78497521,
-53.01595904, -65.15090604, -44.18961066, -3.36317426,
47.29518161, 101.26584258]),
array([ 0.69143462, -20.68577568, 67.93151858, 0.61993839,
-51.62922342, -62.74969861, -42.14277445, -2.57257357,
46.32585875, 98.30994656]),
array([ 0.50641923, -18.65082147, 64.29851727, -0.54231608,
-50.22744165, -60.33677707, -40.09232574, -1.78529785,
45.34959808, 95.34708976]),
array([ 0.32125303, -16.61250416, 60.65366423, -1.69780775,
-48.81745861, -57.92227257, -38.04612435, -1.00382225,
44.37070747, 92.38872647]),
array([ 0.1367625 , -14.57993872, 57.01373061, -2.84247404,
-47.40608386, -55.51635799, -36.01211131, -0.23068726,
43.3934954 , 89.4464146 ]),
array([-4.62363944e-02, -1.25621107e+01, 5.33952080e+01, -3.97225481e+00,
-4.59999887e+01, -5.31290721e+01, -3.39981601e+01, 5.31558052e-01,
4.24221995e+01, 8.65316007e+01]),
array([ -0.22695122, -10.56773191, 49.8140496 , -5.08317304,
-44.60560624, -50.77014909, -32.01193226, 1.28041087,
41.46091743, 83.65541134]),
array([ -0.40462537, -8.60510538, 46.28542887, -6.17141108,
-43.22903739, -48.44885861, -30.06074258, 2.01346585,
40.5135417 , 80.82845754]),
array([ -0.578549 , -6.68200471, 42.82352252, -7.23337841,
-41.87596684, -46.17386255, -28.15143781, 2.72846112,
39.58370166, 78.06065966]),
array([ -0.74806836, -4.80557066, 39.44132542, -8.2657701 ,
-40.55159098, -43.95309231, -26.29029291, 3.4233178 ,
38.67471336, 75.36109688]),
array([ -0.91259328, -2.98222861, 36.15050235, -9.26561389,
-39.26056035, -41.79364985, -24.48292717, 4.09617204,
37.78953946, 72.73788568]),
array([ -1.0716026 , -1.21762809, 32.96127947, -10.23030549,
-38.00693754, -39.70173466, -22.73424171, 4.74539916,
36.93075951, 70.19808926]),
array([ -1.22464753, 0.48339507, 29.88237626, -11.15763131,
-36.79417071, -37.68259709, -21.04837926, 5.36962958,
36.10055103, 67.74765884]),
array([ -1.37135302, 2.11683428, 26.92097769, -12.04577959,
-35.625083 , -35.74051738, -19.42870512, 5.96775683,
35.30068112, 65.39140585]),
array([ -1.5114172 , 3.67950699, 24.08274332, -12.89334003,
-34.50187618, -33.87880879, -17.87780834, 6.53893813,
34.53250781, 63.1330033 ]),
array([ -1.64460911, 5.16903087, 21.37184999, -13.69929365,
-33.42614734, -32.09984247, -16.39752091, 7.08258838,
33.79699041, 60.97501291]),
array([ -1.77076501, 6.58378401, 18.79106293, -14.46299383,
-32.39891675, -30.40509061, -14.98895228, 7.5983683 ,
33.09470718, 58.91893462]),
array([ -1.88978349, 7.92285223, 16.34183006, -15.18414047,
-31.42066461, -28.79518471, -13.65253645, 8.08616802,
32.42587915, 56.96527409]),
array([ -2.00161963, 9.18596647, 14.02439393, -15.86274874,
-30.49137463, -27.26998512, -12.38808857, 8.54608706,
31.79039865, 55.11362395]),
array([ -2.10627861, 10.37343342, 11.83791556, -16.49911403,
-29.6105822 , -25.82865849, -11.1948681 , 8.97841186,
31.18786081, 53.36275442]),
array([ -2.20380891, 11.48606216, 9.78060537, -17.09377467,
-28.77742516, -24.4697597 , -10.07164588, 9.38359174,
30.61759704, 51.71070954]),
array([ -2.29429532, 12.52508939, 7.84985637, -17.64747383,
-27.99069542, -23.19131538, -9.01677257, 9.76221431,
30.07870896, 50.15490529]),
array([ -2.37785214, 13.49210546, 6.04237579, -18.16112154,
-27.24888983, -21.99090645, -8.02824639, 10.11498107,
29.57010192, 48.6922266 ]),
array([ -2.45461647, 14.38898288, 4.35431198, -18.63575804,
-26.550259 , -20.86574764, -7.10377845, 10.44268376,
29.09051726, 47.31912095]),
array([ -2.52474193, 15.21780883, 2.78137399, -19.07251901,
-25.89285334, -19.81276238, -6.24085443, 10.74618199,
28.63856257, 46.0316863 ]),
array([ -2.58839278, 15.9808226 , 1.31894213, -19.47260333,
-25.27456533, -18.82865195, -5.43679142, 11.02638255,
28.21273957, 44.82575234]),
array([-2.64573860e+00, 1.66803586e+01, -3.78317455e-02, -1.98372437e+01,
-2.46931677e+01, -1.79099580e+01, -4.68878959e+00, 1.12842206e+01,
2.78114692e+01, 4.36969539e+01]),
array([ -2.69694947, 17.31879553, -1.29393475, -20.16768012,
-24.14634752, -17.05311846, -3.99397806, 11.52064283,
27.43311416, 42.64079635]),
array([ -2.74219171, 17.89851132, -2.45441458, -20.46513713,
-23.63173553, -16.25451577, -3.34945515, 11.73659269,
27.07599784, 41.65271217]),
array([ -2.78162417, 18.4218446 , -3.52431199, -20.73080331,
-23.1469317 , -15.51051904, -2.75232281, 11.93299774,
26.73842136, 40.72810999]),
array([ -2.81539514, 18.89106166, -4.5086037 , -20.96581452,
-22.68952644, -14.81751901, -2.19971578, 12.11075884,
26.41867724, 39.86241542]),
array([ -2.84363966, 19.30832909, -5.41215503, -21.17123968,
-22.25711819, -14.17195716, -1.68882556, 12.27074135,
26.11506088, 39.05110479]),
array([ -2.86647741, 19.67569163, -6.23968156, -21.34806944,
-21.84732765, -13.57034914, -1.21691978, 12.41376789,
25.8258798 , 38.28973231]),
array([ -2.88401105, 19.99505476, -6.99571901, -21.49720724,
-21.45780903, -13.0093032 , -0.78135738, 12.54061289,
25.54946091, 37.5739514 ]),
array([ -2.896325 , 20.26817172, -7.68460029, -21.61946286,
-21.08625876, -12.48553425, -0.37960005, 12.65199839,
25.28415614, 36.89953088]),
array([-2.90348459e+00, 2.04966345e+01, -8.31043910e+00, -2.17155480e+01,
-2.07304221e+01, -1.19958740e+01, -9.22047572e-03, 1.27485913e+01,
2.50283467e+01, 3.62623670e+01]),
array([ -2.90553568, 20.68186842, -8.87711913, -21.78607419,
-20.38809796, -11.5372783 , 0.3320923 , 12.83100182,
24.78044643, 35.65849147]),
array([ -2.90250466, 20.82513047, -9.38828806, -21.83155194,
-20.05714255, -11.10683105, 0.64652978, 12.89978275,
24.53890392, 35.08407748]),
array([ -2.89439873, 20.92751018, -9.84735586, -21.85239256,
-19.73547203, -10.70174702, 0.93616287, 12.95543005,
24.30220479, 34.53544254]),
array([ -2.88120667, 20.98993384, -10.25749654, -21.84891103,
-19.42106466, -10.3193723 , 1.20294244, 12.99838404,
24.06887319, 34.00905054]),
array([ -2.86289995, 21.01317139, -10.62165305, -21.82133089,
-19.11196276, -9.95718393, 1.44870106, 13.02903159,
23.83747364, 33.50151258]),
array([ -2.83943421, 20.99784593, -10.94254465, -21.7697907 ,
-18.80627485, -9.61278865, 1.67515578, 13.04770894,
23.60661296, 33.00958749]),
array([ -2.81075111, 20.94444592, -11.2226765 , -21.69435221,
-18.50217819, -9.28392137, 1.88391162, 13.05470534,
23.37494278, 32.53018238]),
array([ -2.7767807 , 20.85333988, -11.46435116, -21.59501022,
-18.19792204, -8.9684434 , 2.07646565, 13.05026732,
23.14116268, 32.06035382]),
array([ -2.73744412, 20.72479369, -11.66968152, -21.47170417,
-17.89183189, -8.66434103, 2.25421137, 13.03460373,
22.9040242 , 31.59730988]),
array([ -2.69265675, 20.55899046, -11.84060521, -21.32433148,
-17.58231477, -8.36972436, 2.41844331, 13.00789137,
22.66233581, 31.13841334]),
array([ -2.64233184, 20.35605284, -11.97890007, -21.15276259,
-17.26786588, -8.08282686, 2.57036172, 12.97028126,
22.41496901, 30.68118647]),
array([ -2.58638454, 20.11606773, -12.08620058, -20.95685773,
-16.9470766 , -7.80200552, 2.7110772 , 12.92190556,
22.16086558, 30.22331732]),
array([ -2.52473639, 19.83911336, -12.1640151 , -20.73648529,
-16.61864383, -7.52574193, 2.84161527, 12.86288503,
21.89904598, 29.76266775]),
array([ -2.45732012, 19.52528828, -12.21374366, -20.49154165,
-16.28138076, -7.25264403, 2.96292075, 12.79333698,
21.6286189 , 29.29728316]),
array([ -2.38408488, 19.17474219, -12.23669624, -20.22197221,
-15.93422876, -6.98144874, 3.07586193, 12.71338369,
21.34879176, 28.82540364]),
array([ -2.30500161, 18.78770796, -12.23411114, -19.92779345,
-15.57627034, -6.7110253 , 3.1812346 , 12.62316108,
21.05888208, 28.34547658]),
array([ -2.22006862, 18.36453428, -12.20717338, -19.60911537,
-15.20674272, -6.44037901, 3.27976572, 12.52282748,
20.75832931, 27.85616998]),
array([ -2.12931719, 17.90571836, -12.15703271, -19.26616394,
-14.82505163, -6.16865542, 3.37211696, 12.41257247,
20.44670693, 27.35638635]),
array([ -2.03281693, 17.41193752, -12.08482095, -18.89930292,
-14.43078488, -5.89514442, 3.45888803, 12.29262532,
20.12373416, 26.84527631]),
array([ -1.93068091, 16.88407888, -11.99166828, -18.50905431,
-14.02372493, -5.61928414, 3.54061976, 12.16326297,
19.78928696, 26.32225124]),
array([ -1.82307017, 16.32326597, -11.87871808, -18.09611652,
-13.60386 , -5.34066406, 3.61779714, 12.02481716,
19.44340769, 25.78699425]),
array([ -1.7101975 , 15.73088097, -11.74713996, -17.66137963,
-13.17139278, -5.05902707, 3.69085217, 11.87768038,
19.08631277, 25.23946846]),
array([ -1.59233025, 15.10858153, -11.59814046, -17.20593667,
-12.72674614, -4.77427001, 3.76016668, 11.72231052,
18.71839776, 24.67992183]),
array([ -1.46979191, 14.45831102, -11.43297114, -16.73109027,
-12.27056514, -4.48644237, 3.82607521, 11.55923367,
18.34023932, 24.10888764]),
array([ -1.34296239, 13.78230107, -11.25293367, -16.23835386,
-11.80371459, -4.19574264, 3.88886777, 11.389045 ,
17.9525935 , 23.52717987]),
array([ -1.21227675, 13.08306589, -11.05938165, -15.72944698,
-11.32727192, -3.90251223, 3.94879283, 11.2124075 ,
17.55638988, 22.93588303]),
array([ -1.0782223 , 12.36338762, -10.85371909, -15.20628422,
-10.8425148 , -3.60722666, 4.00606028, 11.0300484 ,
17.15272151, 22.3363358 ]),
array([ -0.94133421, 11.62629291, -10.63739541, -14.67095784,
-10.35090366, -3.31048401, 4.06084456, 10.84275327,
16.74283038, 21.73010874]),
array([ -0.80218939, 10.87502069, -10.41189721, -14.12571428,
-9.85405908, -3.01299073, 4.11328787, 10.65135792,
16.32808865, 21.11897598]),
array([ -0.66139914, 10.11298227, -10.17873696, -13.57292499,
-9.35373469, -2.71554499, 4.16350347, 10.45673825,
15.90997606, 20.50488155]),
array([ -0.51960038, 9.3437146 , -9.93943916, -13.01505253,
-8.85178613, -2.41901805, 4.21157908, 10.25979839,
15.49005402, 19.88990115]),
array([ -0.37744619, 8.57082837, -9.69552439, -12.45461294,
-8.35013702, -2.12433407, 4.2575802 , 10.06145738,
15.06993722, 19.27620041]),
array([ -0.23559553, 7.7979529 , -9.44849209, -11.89413579,
-7.85074315, -1.83244895, 4.30155354, 9.86263514,
14.65126361, 18.6659911 ]),
array([ -0.09470292, 7.02867958, -9.19980253, -11.33612329,
-7.35555597, -1.54432906, 4.34353026, 9.66423794,
14.23566387, 18.0614865 ]),
array([ 0.04459182, 6.26650631, -8.95085903, -10.78300998,
-6.86648676, -1.26093022, 4.38352905, 9.46714416,
13.8247314 , 17.46485774]),
array([ 0.18167324, 5.51478462, -8.7029909 , -10.23712448,
-6.38537267, -0.98317803, 4.42155907, 9.27219066,
13.41999382, 16.87819231]),
array([ 0.31595904, 4.77667171, -8.45743795, -9.70065466, -5.91394578,
-0.71194981, 4.45762257, 9.0801604 , 13.02288699, 16.30345633]),
array([ 0.44690772, 4.05508866, -8.21533696, -9.17561722, -5.45380609,
-0.44805886, 4.49171718, 8.89177165, 12.63473234, 15.74246144]),
array([ 0.57402479, 3.35268627, -7.97771076, -8.66383278, -5.00639925,
-0.19224143, 4.52383789, 8.70766903, 12.25671808, 15.1968375 ]),
array([ 0.69686746, 2.67181923, -7.74545998, -8.16690684, -4.57299943,
0.05485337, 4.55397868, 8.52841676, 11.8898847 , 14.6680113 ]),
array([ 0.81504774, 2.01452887, -7.51935781, -7.68621693, -4.15469757,
0.29267067, 4.58213369, 8.35449397, 11.53511503, 14.15719192]),
array([ 0.92823389, 1.38253472, -7.30004758, -7.22290599, -3.75239498,
0.52075382, 4.60829809, 8.18629234, 11.19312863, 13.66536229]),
array([ 1.03615044, 0.77723408, -7.08804316, -6.77788153, -3.36680198,
0.73874558, 4.63246862, 8.02411567, 10.86448052, 13.19327698]),
array([ 1.13857674, 0.19970924, -6.88373184, -6.35182017, -2.99844121,
0.94638661, 4.65464375, 7.86818143, 10.54956365, 12.74146542]),
array([ 1.23534437, -0.34925898, -6.68737941, -5.94517678, -2.64765488,
1.14351179, 4.67482362, 7.71862394, 10.24861484, 12.31023996]),
array([ 1.32633355, -0.86917221, -6.49913704, -5.55819742, -2.3146154 ,
1.33004458, 4.6930097 , 7.57549881, 9.96172331, 11.89970791]),
array([ 1.41146878, -1.35979135, -6.31904949, -5.19093538, -1.99933855,
1.50598985, 4.7092043 , 7.43878857, 9.68884154, 11.5097867 ]),
array([ 1.49071386, -1.82111047, -6.14706437, -4.84326918, -1.70169862,
1.67142578, 4.72340989, 7.30840896, 9.42979762, 11.14022128]),
array([ 1.56406669, -2.25332877, -5.98304195, -4.51492206, -1.42144461,
1.82649496, 4.73562841, 7.1842157 , 9.18430852, 10.79060296]),
array([ 1.63155382, -2.6568218 , -5.82676527, -4.20548194, -1.15821702,
1.97139519, 4.74586038, 7.06601149, 8.95199393, 10.46038887]),
array([ 1.69322507, -3.03211271, -5.67795009, -3.91442145, -0.91156468,
2.1063703 , 4.75410413, 6.95355301, 8.73239001, 10.14892156]),
array([ 1.74914827, -3.37984437, -5.53625464, -3.64111728, -0.68096104,
2.23170113, 4.76035494, 6.84655767, 8.52496283, 9.85544793]),
array([ 1.79940431, -3.70075294, -5.40128886, -3.3848687 , -0.46581975,
2.34769707, 4.76460431, 6.74471008, 8.32912112, 9.5791374 ]),
array([ 1.84408257, -3.99564321, -5.27262291, -3.14491466, -0.26550905,
2.45468806, 4.76683916, 6.64766808, 8.14422818, 9.31909872]),
array([ 1.88327671, -4.26536633, -5.14979513, -2.92044946, -0.07936501,
2.55301739, 4.76704126, 6.55506823, 7.96961274, 9.0743954 ]),
array([ 1.91708101, -4.51079978, -5.03231911, -2.71063687, 0.09329674,
2.64303526, 4.76518668, 6.46653079, 7.80457879, 8.84405956]),
array([ 1.94558719, -4.73282992, -4.91969002, -2.51462251, 0.25317056,
2.72509312, 4.76124534, 6.38166421, 7.64841421, 8.62710418]),
array([ 1.96888171, -4.93233703, -4.8113903 , -2.33154465, 0.40095183,
2.79953886, 4.75518075, 6.300069 , 7.50039845, 8.42253385]),
array([ 1.98704364, -5.11018283, -4.70689453, -2.16054346, 0.5373294 ,
2.86671272, 4.74694985, 6.22134124, 7.35980905, 8.22935391]),
array([ 2.00014305, -5.26720038, -4.60567385, -2.00076878, 0.6629792 ,
2.92694408, 4.73650304, 6.14507558, 7.22592731, 8.04657834]),
array([ 2.00823987, -5.40418626, -4.50719972, -1.85138658, 0.77855916,
2.98054886, 4.72378431, 6.07086792, 7.09804314, 7.87323638]),
array([ 2.01138327, -5.52189489, -4.41094739, -1.71158423, 0.88470514,
3.02782766, 4.70873165, 5.99831782, 6.97545911, 7.70837803]),
array([ 2.00961158, -5.62103482, -4.31639895, -1.58057469, 0.98202794,
3.0690645 , 4.69127758, 5.92703062, 6.85749406, 7.55107875]),
array([ 2.00295267, -5.70226687, -4.22304623, -1.45759988, 1.071111 ,
3.10452616, 4.67134992, 5.85661951, 6.74348609, 7.40044331]),
array([ 1.99142474, -5.76620405, -4.13039353, -1.34193325, 1.15250906,
3.13446195, 4.64887278, 5.78670747, 6.63279531, 7.25560916]),
array([ 1.97503769, -5.81341299, -4.03796036, -1.23288176, 1.22674728,
3.1591041 , 4.6237678 , 5.71692918, 6.52480635, 7.11574936]),
array([ 1.95379487, -5.84441684, -3.94528423, -1.12978738, 1.29432103,
3.17866844, 4.59595559, 5.64693307, 6.41893072, 6.98007517]),
array([ 1.92769531, -5.85969954, -3.85192352, -1.03202823, 1.35569607,
3.19335553, 4.56535747, 5.57638338, 6.31460917, 6.84783854]),
array([ 1.89673635, -5.85971126, -3.75746059, -0.93901943, 1.41130915,
3.20335214, 4.53189746, 5.50496245, 6.21131402, 6.71833457]),
array([ 1.86091666, -5.84487498, -3.66150503, -0.85021378, 1.46156892,
3.20883299, 4.49550442, 5.43237307, 6.10855171, 6.59090391]),
array([ 1.82023966, -5.81559389, -3.56369722, -0.7651023 , 1.50685703,
3.20996278, 4.45611458, 5.35834115, 6.0058653 , 6.46493531]),
array([ 1.77471723, -5.77225974, -3.46371204, -0.6832147 , 1.54752949,
3.20689841, 4.41367408, 5.28261842, 5.90283732, 6.33986824]),
array([ 1.72437362, -5.71526165, -3.36126278, -0.60411985, 1.58391809,
3.1997914 , 4.36814185, 5.20498535, 5.79909249, 6.21519558]),
array([ 1.66924958, -5.64499541, -3.25610516, -0.52742609, 1.61633198,
3.18879032, 4.31949241, 5.12525413, 5.69430076, 6.09046645]),
array([ 1.60940656, -5.56187291, -3.14804139, -0.45278166, 1.64505924,
3.17404338, 4.26771887, 5.04327168, 5.58818014, 5.96528893]),
array([ 1.54493075, -5.4663315 , -3.0369241 , -0.37987495, 1.67036852,
3.15570091, 4.2128357 , 4.95892255, 5.48049961, 5.83933274]),
array([ 1.47593706, -5.3588429 , -2.92265996, -0.30843471, 1.69251061,
3.1339178 , 4.15488152, 4.87213165, 5.37108163, 5.71233164]),
array([ 1.40257262, -5.23992144, -2.80521299, -0.23823007, 1.71172001,
3.10885566, 4.09392144, 4.78286665, 5.25980447, 5.58408549]),
array([ 1.32501984, -5.11013124, -2.68460712, -0.16907037, 1.72821634,
3.08068481, 4.03004919, 4.69113988, 5.14660386, 5.45446169]),
array([ 1.2434988 , -4.97009198, -2.56092801, -0.10080473, 1.74220577,
3.0495859 , 3.96338853, 4.59700966, 5.03147403, 5.32339598]),
array([ 1.15826874, -4.820483 , -2.43432384, -0.03332128, 1.75388218,
3.01575104, 3.89409424, 4.50058079, 4.91446793, 5.19089221]),
array([ 1.06962867, -4.66204542, -2.30500492, 0.03345406, 1.76342834,
2.9793845 , 3.82235225, 4.40200429, 4.79569639, 5.05702123]),
array([ 0.97791677, -4.495582 , -2.17324197, 0.09955917, 1.77101676,
2.94070283, 3.74837893, 4.30147595, 4.67532621, 4.92191842]),
array([ 0.88350881, -4.32195468, -2.03936309, 0.16499783, 1.77681056,
2.89993436, 3.67241962, 4.19923402, 4.55357708, 4.78578015]),
array([ 0.78681517, -4.14207971, -1.90374917, 0.22974233, 1.78096397,
2.85731814, 3.59474617, 4.09555567, 4.43071728, 4.64885874]),
array([ 0.68827691, -3.95692041, -1.76682799, 0.29373655, 1.78362288,
2.81310227, 3.51565362, 3.99075246, 4.30705816, 4.51145637]),
array([ 0.58836062, -3.76747774, -1.62906703, 0.35689942, 1.78492503,
2.76754167, 3.4354561 , 3.8851648 , 4.18294763, 4.37391768]),
array([ 0.48755243, -3.57477903, -1.49096509, 0.41912883, 1.78500023,
2.72089538, 3.35448206, 3.77915564, 4.0587626 , 4.23662147]),
array([ 0.38635124, -3.37986509, -1.35304315, 0.48030576, 1.78397034,
2.6734235 , 3.27306883, 3.67310335, 3.93490077, 4.09997157]),
array([ 0.28526153, -3.1837764 , -1.21583459, 0.54029861, 1.78194919,
2.62538379, 3.19155695, 3.56739435, 3.81177192, 3.96438722]),
array([ 0.18478584, -2.9875387 , -1.07987519, 0.59896754, 1.77904246,
2.57702826, 3.11028421, 3.46241538, 3.68978898, 3.8302933 ]),
array([ 0.08541742, -2.79214876, -0.94569322, 0.65616876, 1.77534745,
2.52859962, 3.02957985, 3.35854599, 3.56935926, 3.69811064]),
array([-0.01236685, -2.59856059, -0.81380001, 0.71175861, 1.77095291,
2.48032795, 2.94975889, 3.25615126, 3.45087602, 3.5682468 ]),
array([-0.10811296, -2.40767288, -0.68468129, 0.76559721, 1.7659388 ,
2.43242766, 2.871117 , 3.15557513, 3.33471076, 3.44108753]),
array([-0.20139538, -2.22031801, -0.55878952, 0.81755184, 1.76037616,
2.38509471, 2.79392594, 3.05713458, 3.22120636, 3.31698934]),
array([-0.2918218 , -2.03725282, -0.43653755, 0.86749962, 1.75432698,
2.3385044 , 2.71842971, 2.9611146 , 3.11067135, 3.19627309]),
array([-0.37903668, -1.85915169, -0.31829364, 0.91532982, 1.74784414,
2.29280961, 2.64484163, 2.86776438, 3.00337537, 3.07921905]),
array([-0.46272366, -1.68660173, -0.20437791, 0.96094537, 1.74097146,
2.24813963, 2.57334219, 2.77729451, 2.89954591, 2.96606323]),
array([-0.54260689, -1.5201003 , -0.09506038, 1.00426398, 1.73374376,
2.20459946, 2.50407783, 2.68987532, 2.7993663 , 2.85699515]),
array([-0.61845113, -1.3600547 , 0.00943971, 1.04521856, 1.72618711,
2.16226977, 2.43716062, 2.60563629, 2.70297496, 2.75215696]),
array([-0.69006096, -1.20678384, 0.10895321, 1.08375717, 1.71831901,
2.12120723, 2.37266858, 2.52466642, 2.61046574, 2.65164368]),
array([-0.75727902, -1.06052165, 0.20335925, 1.11984257, 1.71014876,
2.08144531, 2.31064685, 2.44701553, 2.52188926, 2.55550466]),
array([-0.81998357, -0.92142195, 0.29258249, 1.1534513 , 1.70167783,
2.04299548, 2.25110931, 2.37269619, 2.43725511, 2.46374583]),
array([-0.87808548, -0.78956448, 0.37658963, 1.18457256, 1.69290026,
2.00584867, 2.19404074, 2.30168635, 2.35653467, 2.37633272]),
array([-0.93152482, -0.66496173, 0.45538538, 1.21320679, 1.68380316,
1.96997692, 2.13939929, 2.2339323 , 2.27966448, 2.293194 ]),
array([-0.98026726, -0.54756636, 0.52900805, 1.23936422, 1.67436718,
1.93533523, 2.08711921, 2.16935195, 2.20654989, 2.21422547]),
array([-1.02430031, -0.43727877, 0.59752503, 1.26306327, 1.66456703,
1.9018635 , 2.03711365, 2.10783835, 2.13706893, 2.13929408]),
array([-1.06362971, -0.33395478, 0.66102823, 1.28432898, 1.65437206,
1.86948844, 1.98927763, 2.04926318, 2.07107627, 2.06824224]),
array([-1.09827592, -0.23741307, 0.71962965, 1.30319155, 1.64374677,
1.83812553, 1.9434908 , 1.99348018, 2.00840704, 2.00089187]),
array([-1.12827095, -0.14744222, 0.77345723, 1.31968494, 1.63265142,
1.80768091, 1.89962033, 1.94032861, 1.94888066, 1.93704843]),
array([-1.15365551, -0.06380733, 0.82265094, 1.33384561, 1.62104261,
1.77805326, 1.85752347, 1.88963634, 1.89230434, 1.87650476])]
plt.figure(figsize=(12,9))
axes = plt.subplot(111)
axes.plot(alphas,w_list)
axes.set_xscale("log")
# Create a series of shrinkage coefficients (alpha values)
alphas = np.logspace(-2,5,200)
# Fit one ridge model for each shrinkage coefficient in the series above
# List that collects the regression coefficients from each fit
w_list = []
ridge = Ridge(fit_intercept=False)
for alpha in alphas:
    ridge.set_params(alpha=alpha)
    ridge.fit(x_train,y_train)
    w_list.append(ridge.coef_)
plt.figure(figsize=(12,9))
axes = plt.subplot(111)
axes.plot(alphas,w_list)
axes.set_xscale("log")
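Each point on the ridge trace above is the solution of a regularized least-squares problem. Assuming the objective that scikit-learn documents for Ridge, ||y − Xw||² + α||w||², with no intercept, that solution has the closed form w = (XᵀX + αI)⁻¹Xᵀy: the least-squares normal equation with αI added (written here with samples as the rows of X). The following is a minimal sketch on made-up data. It compares the closed form with Ridge(fit_intercept=False) and shows the coefficient norm shrinking as α grows.

```python
import numpy as np
from sklearn.linear_model import Ridge

# Synthetic data, made up for illustration: 30 samples, 5 features
rng = np.random.default_rng(42)
X = rng.normal(size=(30, 5))
y = X @ np.array([1.0, -2.0, 0.5, 0.0, 3.0]) + 0.1 * rng.normal(size=30)

alpha = 10.0
n_features = X.shape[1]

# Closed-form ridge solution: solve (X^T X + alpha*I) w = X^T y
w_closed = np.linalg.solve(X.T @ X + alpha * np.eye(n_features), X.T @ y)

# The same problem solved by scikit-learn (no intercept, as in the code above)
w_sklearn = Ridge(alpha=alpha, fit_intercept=False).fit(X, y).coef_

# Larger alpha shrinks the whole coefficient vector toward 0
norms = [
    np.linalg.norm(np.linalg.solve(X.T @ X + a * np.eye(n_features), X.T @ y))
    for a in (0.1, 10.0, 1000.0)
]

print(w_closed)
print(w_sklearn)
print(norms)
```

Because XᵀX + αI is positive definite for any α > 0, this system is always solvable. That sidesteps the non-invertible XᵀX case from the least-squares derivation.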
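III. Lasso regression
1. Principle
[Lagrange multiplier method]
Adding a constraint on the parameters w produces a shrinkage effect similar to ridge regression. Ridge regression bounds the sum of the squared coefficients. Lasso instead bounds the sum of their absolute values:
$\sum_{k=1}^{n} |w_k| \le \lambda$
When λ is small enough, this constraint forces some coefficients to shrink exactly to 0, which ridge regression does not do. In the equivalent penalized form used by scikit-learn's Lasso, the parameter alpha works in the opposite direction: a larger alpha drives more coefficients to 0.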
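The following is a minimal sketch of the sparsity effect described above. It uses scikit-learn's penalized form, where a larger alpha corresponds to a tighter bound λ. The data and the alpha values are made up for illustration.

```python
import numpy as np
from sklearn.linear_model import Lasso

# Synthetic data: 100 samples, 8 features, 4 of the true coefficients are 0
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 8))
true_w = np.array([3.0, -2.0, 1.5, 0.0, 0.0, 0.5, 0.0, 0.0])
y = X @ true_w + 0.1 * rng.normal(size=100)

alphas = [0.001, 0.1, 1.0, 10.0]
zero_counts = []
for alpha in alphas:
    model = Lasso(alpha=alpha, fit_intercept=False)
    model.fit(X, y)
    # Coordinate descent sets coefficients to exactly 0.0, so == is meaningful here
    zero_counts.append(int(np.sum(model.coef_ == 0.0)))
    print(alpha, np.round(model.coef_, 3))

print(zero_counts)
```

As alpha grows, more coefficients become exactly 0. This is the behaviour the coefficient-path plots below visualize.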
2. Example
from sklearn.linear_model import Lasso
# Create a series of shrinkage coefficients (alpha values)
alphas = np.logspace(-10,-2,200)
# Fit one lasso model for each shrinkage coefficient in the series above
# List that collects the regression coefficients from each fit
w_list = []
lasso = Lasso(fit_intercept=False)
for alpha in alphas:
    lasso.set_params(alpha=alpha)
    lasso.fit(x_train,y_train)
    w_list.append(lasso.coef_)
plt.figure(figsize=(12,9))
axes = plt.subplot(111)
axes.plot(alphas,w_list)
axes.set_xscale("log")
# Create a series of shrinkage coefficients (alpha values)
alphas = np.logspace(-2,5,200)
# Fit one lasso model for each shrinkage coefficient in the series above
# List that collects the regression coefficients from each fit
w_list = []
lasso = Lasso(fit_intercept=False)
for alpha in alphas:
    lasso.set_params(alpha=alpha)
    lasso.fit(x_train,y_train)
    w_list.append(lasso.coef_)
plt.figure(figsize=(12,9))
axes = plt.subplot(111)
axes.plot(alphas,w_list)
axes.set_xscale("log")
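The loops above refit one Lasso model per alpha. scikit-learn also provides sklearn.linear_model.lasso_path, which computes the coefficients for a whole grid of alphas in one call. Like the models above, it fits no intercept. Here is a minimal sketch on made-up data, letting the function choose its default alpha grid.

```python
import numpy as np
from sklearn.linear_model import lasso_path

# Synthetic data: 60 samples, 6 features, 3 nonzero true coefficients
rng = np.random.default_rng(0)
X = rng.normal(size=(60, 6))
y = X @ np.array([2.0, 0.0, -1.0, 0.0, 0.5, 0.0]) + 0.1 * rng.normal(size=60)

# alphas come back in decreasing order; coefs has shape (n_features, n_alphas)
alphas, coefs, _ = lasso_path(X, y)

# Number of nonzero coefficients at each alpha (largest alpha first)
nonzero_per_alpha = (coefs != 0).sum(axis=0)
print(alphas[:3])
print(nonzero_per_alpha)
```

Plotting coefs.T against alphas on a log x-axis gives the same kind of coefficient-path figure as the loops above.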
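IV. Comparing ordinary linear regression, ridge regression, and lasso regression
Create some data with numpy
import numpy as np
# 50 samples with 200 features each
train = np.random.randn(50,200)
# Define some custom regression coefficients
coef = np.random.randn(200)
# For convenience, set 190 of them to 0
ind = np.arange(200)
np.random.shuffle(ind)
coef[ind[:190]] = 0
coef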
array([ 0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0.08636003, 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , -0.22477679,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , -1.68022154, 0. , 0. , -0.98508379,
0. , 0. , 0. , 0. , 0. ,
0. , -0.49212469, 0. , -0.38555283, 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , -1.59506153, 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , -1.40028956, -0.02936659, 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 3.23805861, 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ])
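from sklearn.model_selection import train_test_split
# Define a regression equation from the custom coefficients
y = np.dot(train,coef)
# Add a small amount of noise to y, one value per sample
# (np.random.normal(200) would return a single scalar with mean 200, not 200 noise values)
y += np.random.normal(size=50)*0.01
x_train,x_test,y_train,y_test = train_test_split(train,y,test_size=0.3)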
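Create the three models, train each one, and score it on the test set (for regressors, score returns the R² coefficient of determination)
from sklearn.linear_model import LinearRegression, Ridge, Lasso
lgr = LinearRegression()
lgr.fit(x_train,y_train)
lgr.score(x_test,y_test)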
0.011130491964873257
r = Ridge(0.000001)
r.fit(x_train,y_train)
r.score(x_test,y_test)
0.0008063155766956376
l = Lasso(0.1)
l.fit(x_train,y_train)
l.score(x_test,y_test)
0.8218723717861744
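These scores fit the m < n case from the least-squares section. The training split has 35 samples but 200 features, so the least-squares system has infinitely many exact solutions. Ordinary linear regression picks one that reproduces the training labels almost perfectly but generalizes poorly. Lasso's L1 penalty instead favours a sparse coefficient vector, like the one that generated the data. The following is a minimal sketch of the same setup. It assumes a fixed random seed, unlike the original run, so its numbers will differ from the scores above.

```python
import numpy as np
from sklearn.linear_model import LinearRegression, Lasso
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(0)
train = rng.normal(size=(50, 200))
coef = rng.normal(size=200)
coef[rng.permutation(200)[:190]] = 0  # keep only 10 nonzero coefficients
y = train @ coef + 0.01 * rng.normal(size=50)

x_train, x_test, y_train, y_test = train_test_split(
    train, y, test_size=0.3, random_state=0
)

lgr = LinearRegression().fit(x_train, y_train)
lasso = Lasso(alpha=0.1).fit(x_train, y_train)

# With 35 training samples and 200 features, least squares can interpolate the training set
train_r2_ols = lgr.score(x_train, y_train)
test_r2_ols = lgr.score(x_test, y_test)
test_r2_lasso = lasso.score(x_test, y_test)
n_nonzero_lasso = int(np.sum(lasso.coef_ != 0))
print(train_r2_ols, test_r2_ols, test_r2_lasso, n_nonzero_lasso)
```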
Plot the regression coefficients estimated by the three models next to the true regression coefficients
plt.figure(figsize=(12,9))
axes = plt.subplot(221)
axes.set_title("True")
axes.plot(coef)
axes = plt.subplot(222)
axes.set_title("lgr")
axes.plot(lgr.coef_)
axes = plt.subplot(223)
axes.set_title("ridge")
axes.plot(r.coef_)
axes = plt.subplot(224)
axes.set_title("Lasso")
axes.plot(l.coef_)
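V. Exercises
1. Use several methods to perform regression on the boston dataset, plot the regression results, and compare how well the methods perform
# note: load_boston was deprecated in scikit-learn 1.0 and removed in 1.2, so this snippet needs an older scikit-learn
from sklearn.datasets import load_boston
boston = load_boston()
x = boston.data
y = boston.target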
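load_boston is not available in scikit-learn 1.2 and later. The following is a minimal sketch of the exercise's comparison step using the bundled diabetes dataset (load_diabetes) as a stand-in. The dataset choice, the alpha values, and the use of 5-fold cross-validated R² as the metric are all assumptions made for illustration.

```python
import numpy as np
from sklearn.datasets import load_diabetes
from sklearn.linear_model import LinearRegression, Ridge, Lasso
from sklearn.model_selection import cross_val_score

x, y = load_diabetes(return_X_y=True)

models = {
    "LinearRegression": LinearRegression(),
    "Ridge": Ridge(alpha=1.0),
    "Lasso": Lasso(alpha=0.1),
}

scores = {}
for name, model in models.items():
    # Mean R^2 over 5 folds (the default score for regressors)
    scores[name] = cross_val_score(model, x, y, cv=5).mean()
    print(f"{name}: {scores[name]:.3f}")
```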
2. Predict the age of abalone
import pandas as pd
abalone = pd.read_csv("../data/abalone.txt",sep="\t",header=None)
abalone
| | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
---|---|---|---|---|---|---|---|---|---|
0 | 1 | 0.455 | 0.365 | 0.095 | 0.5140 | 0.2245 | 0.1010 | 0.1500 | 15 |
1 | 1 | 0.350 | 0.265 | 0.090 | 0.2255 | 0.0995 | 0.0485 | 0.0700 | 7 |
2 | -1 | 0.530 | 0.420 | 0.135 | 0.6770 | 0.2565 | 0.1415 | 0.2100 | 9 |
3 | 1 | 0.440 | 0.365 | 0.125 | 0.5160 | 0.2155 | 0.1140 | 0.1550 | 10 |
4 | 0 | 0.330 | 0.255 | 0.080 | 0.2050 | 0.0895 | 0.0395 | 0.0550 | 7 |
5 | 0 | 0.425 | 0.300 | 0.095 | 0.3515 | 0.1410 | 0.0775 | 0.1200 | 8 |
6 | -1 | 0.530 | 0.415 | 0.150 | 0.7775 | 0.2370 | 0.1415 | 0.3300 | 20 |
7 | -1 | 0.545 | 0.425 | 0.125 | 0.7680 | 0.2940 | 0.1495 | 0.2600 | 16 |
8 | 1 | 0.475 | 0.370 | 0.125 | 0.5095 | 0.2165 | 0.1125 | 0.1650 | 9 |
9 | -1 | 0.550 | 0.440 | 0.150 | 0.8945 | 0.3145 | 0.1510 | 0.3200 | 19 |
10 | -1 | 0.525 | 0.380 | 0.140 | 0.6065 | 0.1940 | 0.1475 | 0.2100 | 14 |
11 | 1 | 0.430 | 0.350 | 0.110 | 0.4060 | 0.1675 | 0.0810 | 0.1350 | 10 |
12 | 1 | 0.490 | 0.380 | 0.135 | 0.5415 | 0.2175 | 0.0950 | 0.1900 | 11 |
13 | -1 | 0.535 | 0.405 | 0.145 | 0.6845 | 0.2725 | 0.1710 | 0.2050 | 10 |
14 | -1 | 0.470 | 0.355 | 0.100 | 0.4755 | 0.1675 | 0.0805 | 0.1850 | 10 |
15 | 1 | 0.500 | 0.400 | 0.130 | 0.6645 | 0.2580 | 0.1330 | 0.2400 | 12 |
16 | 0 | 0.355 | 0.280 | 0.085 | 0.2905 | 0.0950 | 0.0395 | 0.1150 | 7 |
17 | -1 | 0.440 | 0.340 | 0.100 | 0.4510 | 0.1880 | 0.0870 | 0.1300 | 10 |
18 | 1 | 0.365 | 0.295 | 0.080 | 0.2555 | 0.0970 | 0.0430 | 0.1000 | 7 |
19 | 1 | 0.450 | 0.320 | 0.100 | 0.3810 | 0.1705 | 0.0750 | 0.1150 | 9 |
20 | 1 | 0.355 | 0.280 | 0.095 | 0.2455 | 0.0955 | 0.0620 | 0.0750 | 11 |
21 | 0 | 0.380 | 0.275 | 0.100 | 0.2255 | 0.0800 | 0.0490 | 0.0850 | 10 |
22 | -1 | 0.565 | 0.440 | 0.155 | 0.9395 | 0.4275 | 0.2140 | 0.2700 | 12 |
23 | -1 | 0.550 | 0.415 | 0.135 | 0.7635 | 0.3180 | 0.2100 | 0.2000 | 9 |
24 | -1 | 0.615 | 0.480 | 0.165 | 1.1615 | 0.5130 | 0.3010 | 0.3050 | 10 |
25 | -1 | 0.560 | 0.440 | 0.140 | 0.9285 | 0.3825 | 0.1880 | 0.3000 | 11 |
26 | -1 | 0.580 | 0.450 | 0.185 | 0.9955 | 0.3945 | 0.2720 | 0.2850 | 11 |
27 | 1 | 0.590 | 0.445 | 0.140 | 0.9310 | 0.3560 | 0.2340 | 0.2800 | 12 |
28 | 1 | 0.605 | 0.475 | 0.180 | 0.9365 | 0.3940 | 0.2190 | 0.2950 | 15 |
29 | 1 | 0.575 | 0.425 | 0.140 | 0.8635 | 0.3930 | 0.2270 | 0.2000 | 11 |
… | … | … | … | … | … | … | … | … | … |
4147 | 1 | 0.695 | 0.550 | 0.195 | 1.6645 | 0.7270 | 0.3600 | 0.4450 | 11 |
4148 | 1 | 0.770 | 0.605 | 0.175 | 2.0505 | 0.8005 | 0.5260 | 0.3550 | 11 |
4149 | 0 | 0.280 | 0.215 | 0.070 | 0.1240 | 0.0630 | 0.0215 | 0.0300 | 6 |
4150 | 0 | 0.330 | 0.230 | 0.080 | 0.1400 | 0.0565 | 0.0365 | 0.0460 | 7 |
4151 | 0 | 0.350 | 0.250 | 0.075 | 0.1695 | 0.0835 | 0.0355 | 0.0410 | 6 |
4152 | 0 | 0.370 | 0.280 | 0.090 | 0.2180 | 0.0995 | 0.0545 | 0.0615 | 7 |
4153 | 0 | 0.430 | 0.315 | 0.115 | 0.3840 | 0.1885 | 0.0715 | 0.1100 | 8 |
4154 | 0 | 0.435 | 0.330 | 0.095 | 0.3930 | 0.2190 | 0.0750 | 0.0885 | 6 |
4155 | 0 | 0.440 | 0.350 | 0.110 | 0.3805 | 0.1575 | 0.0895 | 0.1150 | 6 |
4156 | 1 | 0.475 | 0.370 | 0.110 | 0.4895 | 0.2185 | 0.1070 | 0.1460 | 8 |
4157 | 1 | 0.475 | 0.360 | 0.140 | 0.5135 | 0.2410 | 0.1045 | 0.1550 | 8 |
4158 | 0 | 0.480 | 0.355 | 0.110 | 0.4495 | 0.2010 | 0.0890 | 0.1400 | 8 |
4159 | -1 | 0.560 | 0.440 | 0.135 | 0.8025 | 0.3500 | 0.1615 | 0.2590 | 9 |
4160 | -1 | 0.585 | 0.475 | 0.165 | 1.0530 | 0.4580 | 0.2170 | 0.3000 | 11 |
4161 | -1 | 0.585 | 0.455 | 0.170 | 0.9945 | 0.4255 | 0.2630 | 0.2845 | 11 |
4162 | 1 | 0.385 | 0.255 | 0.100 | 0.3175 | 0.1370 | 0.0680 | 0.0920 | 8 |
4163 | 0 | 0.390 | 0.310 | 0.085 | 0.3440 | 0.1810 | 0.0695 | 0.0790 | 7 |
4164 | 0 | 0.390 | 0.290 | 0.100 | 0.2845 | 0.1255 | 0.0635 | 0.0810 | 7 |
4165 | 0 | 0.405 | 0.300 | 0.085 | 0.3035 | 0.1500 | 0.0505 | 0.0880 | 7 |
4166 | 0 | 0.475 | 0.365 | 0.115 | 0.4990 | 0.2320 | 0.0885 | 0.1560 | 10 |
4167 | 1 | 0.500 | 0.380 | 0.125 | 0.5770 | 0.2690 | 0.1265 | 0.1535 | 9 |
4168 | -1 | 0.515 | 0.400 | 0.125 | 0.6150 | 0.2865 | 0.1230 | 0.1765 | 8 |
4169 | 1 | 0.520 | 0.385 | 0.165 | 0.7910 | 0.3750 | 0.1800 | 0.1815 | 10 |
4170 | 1 | 0.550 | 0.430 | 0.130 | 0.8395 | 0.3155 | 0.1955 | 0.2405 | 10 |
4171 | 1 | 0.560 | 0.430 | 0.155 | 0.8675 | 0.4000 | 0.1720 | 0.2290 | 8 |
4172 | -1 | 0.565 | 0.450 | 0.165 | 0.8870 | 0.3700 | 0.2390 | 0.2490 | 11 |
4173 | 1 | 0.590 | 0.440 | 0.135 | 0.9660 | 0.4390 | 0.2145 | 0.2605 | 10 |
4174 | 1 | 0.600 | 0.475 | 0.205 | 1.1760 | 0.5255 | 0.2875 | 0.3080 | 9 |
4175 | -1 | 0.625 | 0.485 | 0.150 | 1.0945 | 0.5310 | 0.2610 | 0.2960 | 10 |
4176 | 1 | 0.710 | 0.555 | 0.195 | 1.9485 | 0.9455 | 0.3765 | 0.4950 | 12 |
4177 rows × 9 columns
abalone.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 4177 entries, 0 to 4176
Data columns (total 9 columns):
0 4177 non-null int64
1 4177 non-null float64
2 4177 non-null float64
3 4177 non-null float64
4 4177 non-null float64
5 4177 non-null float64
6 4177 non-null float64
7 4177 non-null float64
8 4177 non-null int64
dtypes: float64(7), int64(2)
memory usage: 293.8 KB
data,target = abalone.iloc[:,:-1],abalone[8]
from sklearn.model_selection import train_test_split
x_train,x_test,y_train,y_test = train_test_split(data,target,test_size=0.2,random_state=33)
from sklearn.linear_model import LinearRegression
lgr = LinearRegression()
lgr.fit(x_train,y_train)
LinearRegression(copy_X=True, fit_intercept=True, n_jobs=None,
normalize=False)
y_train_pred = lgr.predict(x_train)
y_test_pred = lgr.predict(x_test)
from sklearn.metrics import mean_squared_error
# Empirical error: MSE on the training set
mean_squared_error(y_pred=y_train_pred,y_true=y_train)
4.788743105709357
# Generalization error: MSE on the test set
mean_squared_error(y_pred=y_test_pred,y_true=y_test)
5.400833865843481
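Both errors above are mean squared errors: the average of (yᵢ − ŷᵢ)² over the samples. The empirical error is computed on the training set and the generalization error on the held-out test set. Here is a minimal sketch that checks mean_squared_error against that formula on made-up numbers.

```python
import numpy as np
from sklearn.metrics import mean_squared_error

# Made-up true and predicted ring counts
y_true = np.array([15.0, 7.0, 9.0, 10.0])
y_pred = np.array([12.5, 8.0, 9.5, 11.0])

# MSE = (1/m) * sum((y_i - y_pred_i)^2); here (6.25 + 1 + 0.25 + 1) / 4 = 2.125
mse_manual = np.mean((y_true - y_pred) ** 2)
mse_sklearn = mean_squared_error(y_true, y_pred)
print(mse_manual, mse_sklearn)
```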
Compare the predicted labels with the true labels
y_test_pred,y_test
(array([10.34562748, 12.76893803, 10.51757297, 9.35699601, 12.40099643,
9.96956442, 16.53271142, 8.75505569, 9.85857875, 8.8088983 ,
5.52115525, 13.21924253, 6.74794518, 10.59134682, 11.50256564,
5.42903798, 9.02509152, 9.2271902 , 9.84475711, 8.11947872,
9.61878283, 9.3539994 , 4.49745441, 8.74078674, 6.9193072 ,
7.03300145, 9.51341218, 9.25786869, 6.36914483, 9.90709391,
9.12821998, 9.64044939, 7.57439083, 9.14777883, 6.70700623,
7.19311105, 10.9766664 , 11.50916271, 8.26596128, 9.21464722,
15.72065136, 9.14721802, 8.61706087, 16.0908889 , 7.70064269,
14.04372113, 7.96114423, 11.38154855, 8.27955813, 9.95148543,
9.66246397, 14.01007638, 11.5552425 , 11.99003897, 9.18041889,
11.31471848, 7.54010381, 7.59676727, 8.79229094, 8.23529555,
14.93672265, 13.03401657, 8.9340301 , 12.4250975 , 8.6620213 ,
6.16781833, 9.18355357, 8.29315578, 8.73534441, 10.17968633,
9.05077372, 10.70145638, 7.74014836, 11.64467729, 10.22380578,
7.58076383, 9.20768887, 10.68789508, 11.07176169, 9.89633731,
10.39720067, 6.44630427, 14.32414918, 10.70168217, 5.65612814,
7.8334721 , 9.28804147, 12.07492257, 6.36003465, 11.10169728,
14.74266527, 9.5917279 , 8.01595203, 7.60872184, 7.89426983,
7.0434553 , 8.83057315, 7.69603938, 8.60007552, 9.5362269 ,
8.50645846, 8.3581402 , 10.44532009, 11.75963369, 8.08239627,
22.99534916, 8.44956123, 10.07699645, 7.10642895, 8.61704262,
8.90650343, 12.00071058, 4.58544581, 8.94759466, 7.46508244,
6.37639726, 9.56744653, 10.95996551, 19.17093153, 11.78419529,
10.30216824, 8.66676486, 8.50827877, 6.74476366, 17.8253854 ,
11.86193819, 10.68404293, 7.21853331, 8.78337896, 7.60362848,
7.70449167, 9.0791696 , 10.40239299, 8.24687014, 10.86123023,
8.24129673, 13.75913642, 6.26683464, 7.4080908 , 11.91298723,
12.79318882, 10.29407642, 11.41394843, 7.83175652, 7.52304812,
5.87673202, 8.69369741, 11.74719833, 11.43942735, 10.7405778 ,
9.2901268 , 7.88332673, 10.2846143 , 8.21931734, 7.99602385,
11.26246153, 9.38822795, 9.77840317, 9.22764449, 8.36514324,
14.80607221, 10.78189073, 12.0681189 , 13.18669636, 8.72083003,
9.14152678, 9.23643882, 9.20165143, 11.29722767, 10.13968778,
10.13770535, 4.79095144, 14.52565326, 11.41844158, 7.57285034,
12.05507262, 9.67518999, 6.75438673, 10.33233942, 8.25141407,
10.15042801, 8.89477557, 11.73169216, 7.65474894, 16.75653837,
10.76572629, 8.76706278, 9.36447505, 10.11643381, 13.86576779,
11.3345478 , 10.59291526, 15.99411597, 9.42801635, 8.86627716,
9.1063562 , 10.93480532, 12.88292979, 5.85400891, 8.65723615,
8.24369735, 11.42572667, 9.89808239, 8.89539417, 11.0903882 ,
10.17867834, 10.13465043, 6.59411311, 9.39686688, 7.41978867,
5.9748424 , 10.94614345, 10.26587309, 9.64968159, 9.51737495,
12.50391032, 8.9721678 , 8.85854543, 9.47459442, 11.05112905,
10.41488392, 12.90929434, 10.80101717, 7.06003044, 8.19973188,
11.0437563 , 10.27184748, 14.93222767, 7.1164997 , 9.78723485,
11.1075667 , 10.61025383, 9.08733991, 11.22579288, 16.89587247,
13.5448674 , 6.59437309, 9.87252229, 9.6420551 , 12.94940225,
16.80776927, 7.94581836, 9.13216288, 7.68005113, 9.02231094,
4.53072399, 6.11404307, 10.23691813, 10.98574462, 7.48689976,
8.90898647, 9.41941497, 7.69007985, 7.91575528, 11.05457928,
11.57218825, 5.85837268, 6.21919216, 10.18802958, 9.43340553,
10.43480437, 7.30373301, 9.710702 , 9.15944779, 6.90449288,
10.06086741, 8.46244175, 8.4567991 , 11.28159145, 13.18280751,
5.41342345, 9.87576857, 12.05688528, 10.40469332, 8.33878908,
7.52112953, 6.18841904, 11.10621726, 10.94558695, 9.29394031,
11.86845204, 14.3848581 , 8.13983095, 12.97781028, 10.28731957,
10.63549638, 10.37828637, 6.67571025, 6.29113602, 13.5047071 ,
9.90670018, 10.25684164, 9.89996042, 6.71867356, 10.15091574,
11.54141269, 6.10654062, 8.41172249, 9.7939284 , 8.63593295,
9.0231669 , 9.99900535, 7.97489453, 12.56887258, 7.30867039,
11.82325608, 6.51471656, 8.55990944, 7.66879371, 7.21015924,
6.72491316, 8.78687421, 11.99674862, 6.86149246, 10.28015529,
11.19577047, 8.04952436, 9.34727802, 8.44609408, 5.36033087,
9.41018527, 9.67104305, 7.20858121, 7.85873418, 9.35654238,
9.27639451, 10.02303972, 12.65497016, 5.26042786, 9.74674194,
7.94083535, 5.33033901, 9.22286018, 8.00879568, 10.71634599,
7.83819729, 6.40368649, 7.48263793, 16.63720659, 11.65842115,
14.58673875, 7.20018461, 9.44273145, 10.6900689 , 9.95604718,
8.69660536, 6.23933217, 15.12437049, 11.75962076, 9.85470139,
10.01469068, 10.9205047 , 6.96945823, 9.62429732, 9.39861119,
12.10905553, 7.1479343 , 12.99231475, 8.9439975 , 9.25432017,
7.46498721, 7.93858435, 9.99405556, 9.88467219, 15.39357449,
12.30731316, 9.95650368, 8.90284555, 8.25782025, 7.7850823 ,
14.13890674, 10.15529577, 13.64051703, 6.671871 , 9.64220517,
12.12535141, 9.57863712, 7.86216909, 11.44930096, 8.91163952,
8.55843793, 6.77029806, 9.28130721, 9.71478238, 9.91481275,
13.33652236, 7.43657106, 8.76188157, 12.15499552, 7.02680603,
10.91172134, 7.78542839, 13.70431539, 9.61252786, 10.03772817,
10.33303903, 9.7888133 , 8.27541502, 8.30626973, 9.62187144,
11.4980196 , 11.54493605, 8.52835116, 10.40106194, 6.84980625,
9.53666413, 5.05721826, 9.60654108, 16.10996475, 8.42823586,
11.36559873, 7.33966469, 10.16024593, 5.18983571, 10.35315024,
8.46199453, 11.31141046, 9.20915901, 8.54026585, 7.75214339,
12.55021747, 9.20619937, 9.16429721, 7.39544758, 10.43931142,
9.45445398, 10.56996274, 12.12860551, 7.446678 , 10.19972648,
9.73946411, 9.79825512, 11.65704415, 6.14132045, 13.94716034,
10.41136325, 8.43778589, 11.89597281, 11.61756321, 15.660813 ,
11.15593551, 14.41431771, 6.99092804, 10.08795285, 10.55563459,
10.95567835, 8.50718939, 8.07449085, 7.19427736, 12.74243369,
15.04852952, 5.5481162 , 10.19324921, 7.19050152, 9.68865271,
10.72626634, 5.96123723, 13.01847907, 10.09599092, 11.63147562,
10.68648304, 6.18256205, 11.47482149, 14.89379516, 9.23733159,
9.54792857, 11.32126467, 10.49715605, 5.31080804, 11.27802208,
8.62698204, 4.86761906, 9.29160101, 6.99999956, 8.53098422,
12.41015505, 12.8040076 , 10.9436541 , 11.14439778, 9.4173184 ,
11.62524134, 8.02387405, 10.44765333, 7.93195502, 8.76900177,
8.68047284, 9.80552181, 8.84577667, 8.42838737, 13.38374642,
8.46852124, 8.74843379, 7.26793744, 12.36247718, 10.03228762,
8.71213837, 14.21260352, 10.70318253, 8.58167185, 8.99351261,
8.22123824, 8.69921129, 9.21877251, 8.64980758, 6.90543364,
11.36101291, 11.96591184, 9.82072732, 10.68399448, 10.81703289,
10.03308353, 6.95182703, 8.63947771, 10.84307413, 10.64234729,
8.7783278 , 9.61335235, 12.85643974, 12.55872016, 10.28611753,
7.9693986 , 8.24556589, 10.86742808, 7.90205785, 11.45664819,
11.33442977, 9.56434287, 8.9444071 , 8.31509213, 10.65276168,
12.1616929 , 9.48380851, 10.43553827, 13.04695448, 9.7415096 ,
9.82294294, 9.08517964, 14.24382674, 10.61594523, 9.71600714,
10.28392301, 13.79721897, 10.56927627, 12.54728447, 9.21493853,
10.43492653, 10.56754141, 10.34120327, 11.29507484, 10.84752567,
10.16022385, 12.88798341, 9.12472937, 9.53381355, 13.13500151,
11.06890691, 12.90187368, 10.95615145, 8.43273966, 9.94255697,
14.38316344, 12.94626489, 10.76023586, 8.19222441, 5.94592641,
11.42347504, 10.66331285, 9.87597865, 11.21270134, 13.84869342,
8.06267015, 9.96495203, 8.8405467 , 5.8986351 , 10.97596867,
8.62930373, 9.99793435, 8.61183133, 12.01482031, 8.21957017,
7.47775915, 12.33167787, 15.4905633 , 12.5376807 , 9.17081378,
14.08115034, 10.75282432, 10.42021638, 7.61483407, 9.22444034,
10.63135592, 7.70302093, 12.21303167, 8.57008394, 7.04538404,
13.632883 , 9.84906044, 8.10686084, 8.38772905, 11.45433208,
8.72058359, 10.9949074 , 10.11175791, 8.49121239, 9.00690736,
7.64230464, 10.59889594, 13.13885648, 9.09230638, 10.58112202,
8.66701554, 13.87360662, 12.57924734, 9.17551288, 5.91298827,
12.46388829, 10.74016825, 5.92353788, 9.00650679, 9.87142408,
11.14876935, 9.71934326, 15.37099453, 10.1592695 , 10.88249155,
12.1930542 , 8.89507254, 12.87376853, 9.06026533, 8.52943353,
12.49172884, 8.54544227, 8.26545272, 14.30857713, 10.31471044,
10.41685954, 8.48694603, 13.13878359, 9.65727484, 12.53442581,
10.06723067, 8.6945423 , 7.74719372, 13.72931944, 13.48689231,
11.00443578, 11.93086951, 8.67023639, 9.84628969, 13.43474394,
12.18756227, 17.34895729, 5.32757108, 9.18622744, 9.42755469,
11.19297418, 9.76454177, 9.13717737, 6.55555197, 8.85251944,
16.23077192, 7.69692518, 10.58378648, 8.45074454, 10.34003858,
13.26228533, 6.25735906, 8.92170526, 9.60418024, 9.57802896,
12.52252003, 13.00841879, 8.56789533, 11.2112358 , 10.08352964,
6.31476677, 11.08159806, 10.05832564, 7.77634462, 8.02135321,
9.70155477, 11.26031489, 10.21152597, 8.40821524, 11.07494623,
15.81063347, 9.98673899, 10.2741934 , 9.45520992, 10.11128925,
7.53615704, 10.67307531, 8.92063899, 11.25374441, 7.69437595,
6.82143654, 4.58247139, 10.99464705, 13.24671261, 10.04270592,
8.29100981, 7.88280014, 13.80715026, 8.93550638, 10.19290486,
10.8270187 , 9.76494852, 11.72882782, 12.4136368 , 7.34303534,
6.95802357, 14.22413467, 10.66553667, 6.54081958, 11.58942981,
10.20412947, 9.95887261, 6.65266227, 13.17245896, 9.97441018,
7.47891011, 18.2137793 , 10.47957616, 9.6634571 , 11.41887348,
10.640508 , 14.40494159, 9.27983006, 13.18352717, 9.56951207,
5.98781165, 7.83519466, 8.98455646, 13.60209449, 10.46156459,
8.2822188 , 6.83223689, 12.39676773, 10.07313627, 12.24937664,
10.37092582, 13.19892103, 9.46623467, 14.80115694, 11.57301343,
8.10179312, 7.83502014, 11.317512 , 10.38801658, 9.14044004,
11.20573693, 9.97917144, 8.87623873, 8.81412447, 9.4809816 ,
10.89728041, 6.74731665, 6.27647383, 9.82524489, 7.52711846,
10.40735987, 12.09856071, 6.24109037, 10.75238838, 7.35071179,
16.12417932, 7.06972118, 7.02353697, 5.91208282, 11.74344756,
9.11920492, 7.36117875, 13.53889381, 10.75964421, 6.88399775,
8.10427422, 9.1502536 , 10.75427543, 9.0379462 , 11.87539714,
11.2898098 , 11.85604343, 10.18166835, 8.36326547, 5.43905537,
8.56991898, 8.09202743, 7.70933265, 9.86051613, 13.92502193,
11.35174334, 11.07093933, 12.77180737, 9.74257494, 12.77124404,
7.16304769, 13.34546963, 8.81576051, 10.5295328 , 11.82673759,
8.77883152, 12.03822361, 10.96821956, 8.7013583 , 9.21700899,
11.39043279, 8.92523665, 14.61354413, 9.37282327, 8.1419122 ,
7.64455142, 13.52341102, 8.31347214, 10.75681192, 9.72821914,
11.91266892, 7.1845481 , 13.64473908, 8.64619931, 7.06611872,
13.77764756, 6.71630228, 9.7022598 , 11.35620195, 15.04437288,
11.1991611 , 6.59374692, 7.85460462, 10.18873794, 7.86798428,
9.61320993, 12.09116906, 9.50087126, 7.27615432, 12.26593176,
8.17174245]), 2806 9
2251 13
3771 10
3819 8
1690 9
1326 8
2213 17
1087 8
3072 10
2133 9
1057 4
291 12
3474 6
2680 10
1739 13
520 3
3638 8
2121 9
1802 11
1148 8
1588 8
963 9
2114 4
3250 9
2005 7
347 6
2065 8
1647 10
241 5
2219 13
..
2749 9
2324 17
2530 9
3816 8
3534 6
2276 14
2016 9
864 11
2672 8
647 9
1229 6
3302 15
2878 8
565 10
373 14
2678 7
3907 13
2709 12
3219 16
3709 10
3630 7
2285 9
3915 11
436 7
3727 7
3212 14
2071 9
4162 8
1012 10
1079 7
Name: 8, Length: 836, dtype: int64)
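# A fisherman has caught an abalone with the following features
import numpy as np
x = np.array([[1,12,23,9,10,5,6,8]])
lgr.predict(x)
array([422.46652202])
These feature values lie far outside the range of the training data. In the rows shown in the table above, columns 1 to 3 are all below 1 and columns 4 to 7 are at most about 2. The prediction of about 422 is therefore an extrapolation of the linear model, not a meaningful age estimate.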
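Before trusting a prediction, the query above can be checked against the training data. The following is a minimal sketch of such a check, which compares each feature of the query with the minimum and maximum seen in training. The helper name out_of_range_features is hypothetical. The three training rows are copied from the table above for illustration. With the real data you would pass x_train.values and the fisherman's x instead.

```python
import numpy as np


def out_of_range_features(x_train, x_query):
    """Hypothetical helper: column indices where any query row falls outside the training min/max."""
    lo = x_train.min(axis=0)
    hi = x_train.max(axis=0)
    outside = (x_query < lo) | (x_query > hi)
    return np.flatnonzero(outside.any(axis=0))


# Rows 0, 2 and 4 of the abalone table (columns 0-7, label column dropped)
x_train_demo = np.array([
    [1, 0.455, 0.365, 0.095, 0.5140, 0.2245, 0.1010, 0.1500],
    [-1, 0.530, 0.420, 0.135, 0.6770, 0.2565, 0.1415, 0.2100],
    [0, 0.330, 0.255, 0.080, 0.2050, 0.0895, 0.0395, 0.0550],
])
x_query = np.array([[1, 12, 23, 9, 10, 5, 6, 8]])

bad = out_of_range_features(x_train_demo, x_query)
print(bad)  # columns 1 to 7 are flagged; column 0 (value 1) is within [-1, 1]
```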