在学习的过程中,对于欠拟合和过拟合这两个概念总有点模糊,现在分享下自己对这两个概念的理解。
无论是在机器学习还是深度学习建模当中都可能会遇到两种最常见结果,一种是过拟合(over- fitting),另一种叫做欠拟合(under- fitting)。
所谓过拟合(over- fitting)其实就是所建的机器学习模型或者是深度学习模型在训练样本中表现得过于优越,导致在验证数据集以及测试数据集中表
现不佳。打个比喻就是当我需要建立好一个模型之后,比如是识别一只狗狗的模型,我需要对这个模型进行训练。恰好,我训练样本中的所有训练图片
都是二哈,那么经过多次迭代训练之后,模型训练好了,并且在训练集中表现得很好。基本上二哈身上的所有特点都涵括进去,那么问题来了!假如我
的测试样本是一只金毛呢?将一只金毛的测试样本放进这个识别狗狗的模型中,很有可能模型最后输出的结果就是金毛不是一条狗(因为这个模型基本
上是按照二哈的特征去打造的)。所以这样就造成了模型过拟合,虽然在训练集上表现得很好,但是在测试集中表现得恰好相反,在性能的角度上讲就
是协方差过大(variance is large),同样在测试集上的损失函数(cost function)会表现得很大。
而欠拟合,还是用刚刚的模型举例,就是可能二哈被提取的特征比较少,导致训练出来的模型不能很好地匹配,表现得很差,甚至二哈都无法识别。
在模型中,过拟合和欠拟合都不合适,其中造成过拟合的原因可以归结为:参数过多。对于过拟合,我们要做的事情就是减少参数。有两种办法,
第一种就是采用梯度下降算法将模型中的损失函数不断减少,那么我们就会在一定范围内求出最优解,最后损失函数不断趋近0 。我们可以再所定义的
损失函数后面加入一项永不为0 的部分,那么最后经过不断优化损失函数最终还是会存在。而这就是所谓的正则化。正则化方法包括L0正则、L1正则和L2
正则,而正则一般是在目标函数之后加上对于的范数。但是在机器学习中一般使用L2正则,下面看具体的原因。
L0范数是指向量中非0 的元素的个数。L1范数是指向量中各个元素绝对值之和,也叫“稀疏规则算子”(Lasso regularization)。两者都可以实现
稀疏性,既然L0可以实现稀疏,为什么不用L0,而要用L1呢?个人理解一是因为L0范数很难优化求解(NP难问题),二是L1范数是L0范数的最优凸近似,
而且它比L0范数要容易优化求解。所以大家才把目光和万千宠爱转于L1范数。
L2范数是指向量各元素的平方和然后求平方根。可以使得W的每个元素都很小,都接近于0 ,但与L1范数不同,它不会让它等于0 ,而是接近于0 。L2正
则项起到使得参数w变小加剧的效果,但是为什么可以防止过拟合呢?一个通俗的理解便是:更小的参数值w意味着模型的复杂度更低,对训练数据的拟合
刚刚好(奥卡姆剃刀),不会过分拟合训练数据,从而使得不会过拟合,以提高模型的泛化能力。
对于神经网络,参数膨胀原因可能是因为随着网路深度的增加,同时参数也不断增加,并且增加速度、规模都很大。那么可以采取减少神经网络规
模(深度)的方法。也可以用一种叫dropout的方法。dropout的思想是当一组参数经过某一层神经元的时候,去掉这一层上的一部分神经元,让参数只
经过一部分神经元进行计算。注意这里的去掉并不是真正意义上的去除,只是让参数不经过一部分神经元计算而已。即在训练时候以一定的概率p来跳过
一定的神经元。
另外增大训练样本规模同样也可以防止过拟合。
对于欠拟合,基本上都发生在训练刚刚开始的时候,对于这种,不断训练之后就会解决,但是如果还是没解决,可以通过增加网络复杂度或者增加
特征(如组合、泛化、相关性三类特征)来解决。
有数据集$\{ ( x_1, y_1) , ( x_2, y_2) , . . . , ( x_n, y_n) \} $, 其中, $x_i = ( x_{ i1} ; x_{ i2} ; x_{ i3} ; . . . ; x_{ id } ) , y_i\in R$< br>
其中n表示变量的数量,d表示每个变量的维度。
可以用以下函数来描述y和x之间的关系:
\begin{ align* }
f( x)
& = \theta_0 + \theta_1x_1 + \theta_2x_2 + . . . + \theta_dx_d \\
& = \sum_{ i= 0 } ^ { d} \theta_ix_i \\
\end{ align* }
如何来确定$\theta$的值,使得$f( x) $尽可能接近y的值呢?均方误差是回归中常用的性能度量,即:
$$J( \theta) = \frac{ 1 } { 2 } \sum_{ j= 1 } ^ { n} ( h_{ \theta} ( x^ { ( i) } ) - y^ { ( i) } ) ^ 2 $$< br>
我们可以选择$\theta$,试图让均方误差最小化。
* 损失函数( Loss Function) :度量单样本预测的错误程度,损失函数值越小,模型就越好。
* 代价函数( Cost Function) :度量全部样本集的平均误差。
* 目标函数( Object Function) :代价函数和正则化函数,最终要优化的函数。
当模型复杂度增加时,有可能对训练集可以模拟的很好,但是预测测试集的效果不好,出现过拟合现象,这就出现了所谓的“结构化风险”。结构风险最
小化即为了防止过拟合而提出来的策略。
1 、梯度下降法
2 、最小二乘法矩阵求解
3 、牛顿法
4 、拟牛顿法
均方误差( MSE) : $\frac{ 1 } { m} \sum ^ { m} _{ i= 1 } ( y^ { ( i) } - \hat y^ { ( i) } ) ^ 2 $
均方根误差( RMSE) :$\sqrt{ MSE} = \sqrt{ \frac{ 1 } { m} \sum ^ { m} _{ i= 1 } ( y^ { ( i) } - \hat y^ { ( i) } ) ^ 2 } $
平均绝对误差( MAE) :$\frac{ 1 } { m} \sum ^ { m} _{ i= 1 } | ( y^ { ( i) } - \hat y^ { ( i) } | $
但以上评价指标都无法消除量纲不一致而导致的误差值差别大的问题,最常用的指标是$R^ 2 $, 可以避免量纲不一致问题
$$R^ 2 : = 1 - \frac{ \sum ^ { m} _{ i= 1 } ( y^ { ( i) } - \hat y^ { ( i) } ) ^ 2 } { \sum ^ { m} _{ i= 1 } ( \bar y - \hat y^ { ( i) } ) ^ 2 } = 1 - \frac{ \frac{ 1 } { m} \sum ^ { m} _{ i= 1 } ( y^ { ( i) } - \hat y^ { ( i) } ) ^ 2 } { \frac{ 1 } { m} \sum ^ { m} _{ i= 1 } ( \bar y - \hat y^ { ( i) } ) ^ 2 } = 1 - \frac{ MSE} { VAR} $$
我们可以把$R^ 2 $理解为,回归模型可以成功解释的数据方差部分在数据固有方差中所占的比例,$R^ 2 $越接近1 ,表示可解释力度越大,模型拟合的效果越好。
fit_intercept : 默认为True , 是否计算该模型的截距。如果使用中心化的数据,可以考虑设置为False , 不考虑截距。注意这里是考虑,一般还是要考虑截距
normalize: 默认为false. 当fit_intercept设置为false的时候,这个参数会被自动忽略。如果为True , 回归器会标准化输入参数:减去平均值,并且除以相应的二范数。当然啦,在这里还是建议将标准化的工作放在训练模型之前。通过设置sklearn. preprocessing. StandardScaler来实现,而在此处设置为false
copy_X : 默认为True , 否则X会被改写
n_jobs: int 默认为1 . 当- 1 时默认使用全部CPUs ??( 这个参数有待尝试)
可用属性:
coef_: 训练后的输入端模型系数,如果label有两个,即y值有两列。那么是一个2D 的array
intercept_: 截距
可用的methods:
fit( X, y, sample_weight= None ) : X: array, 稀疏矩阵 [ n_samples, n_features] y: array [ n_samples, n_targets] sample_weight: 权重 array [ n_samples] 在版本0.17 后添加了sample_weight
get_params( deep= True ) : 返回对regressor 的设置值
predict( X) : 预测 基于 R^ 2 值
score: 评估
参考https: // blog. csdn. net/ weixin_39175124/ article/ details/ 79465558
练习题:请用以下数据(可自行生成尝试,或用其他已有数据集)
1 、首先尝试调用sklearn的线性回归函数进行训练;
2 、用最小二乘法的矩阵求解法训练数据;
3 、用梯度下降法训练数据;
4 、比较各方法得出的结果是否一致。
import numpy as np
np. random. seed( 1234 )
x= np. random. rand( 500 , 3 )
y = x. dot( np. array( [ 4.2 , 5.7 , 10.8 ] ) )
print ( x. shape)
print ( "==========" )
print ( y. shape)
print ( "==========" )
print ( x)
print ( "==========" )
print ( y)
(500, 3)
==========
(500,)
==========
[[0.19151945 0.62210877 0.43772774]
[0.78535858 0.77997581 0.27259261]
[0.27646426 0.80187218 0.95813935]
...
[0.77485433 0.17616405 0.88455879]
[0.09306395 0.20218845 0.37240548]
[0.17365571 0.34523374 0.56773848]]
==========
[ 9.07786127 10.68836829 16.0797263 11.12922286 10.93165403 5.37329728
12.21769522 6.99763629 11.92183168 11.25409795 14.79971809 12.22758633
10.96579249 6.1503401 10.46766117 8.55032711 5.45997705 5.8491499
16.33001627 18.18512486 9.92349225 3.58288103 8.83293151 11.05386043
5.06704001 11.94254694 6.45817975 14.56271085 11.87672677 11.83621297
9.58556266 12.16169141 7.54118908 14.83178201 16.95711344 7.2010966
9.41633216 13.36307267 8.89049525 13.53761433 16.59743526 5.81093283
5.94787817 8.45044079 7.09175585 8.17099288 4.67392227 15.3674629
10.1211939 6.69616873 10.66032044 13.32990072 15.52396023 11.07875215
19.80253898 11.41929488 15.69679186 12.79052379 9.04353678 15.40296188
14.50745991 5.85626504 10.82684019 6.46258723 6.29346456 6.11076124
19.33324427 17.32821695 7.57283027 7.07157768 10.187796 10.37520121
9.43062138 6.39246954 12.64853329 6.98462807 4.11688474 17.10762839
7.29038156 13.22771353 1.77970124 6.3287382 4.0317472 13.53221294
8.69087192 12.42392897 12.53767618 12.76175689 15.22297016 14.04727772
8.616245 12.60111427 8.7117879 3.84238311 9.47956126 14.8911322
8.14947582 18.70157352 11.12496826 8.30604101 11.48859452 7.55171731
14.24582487 7.62440733 2.94020209 12.60147821 16.53867848 12.26784659
12.49106895 1.58145023 7.53335905 8.35312782 16.01726403 17.38603735
13.85058557 12.81793394 9.56437868 11.94597849 13.96680295 11.35959062
8.89138927 12.36872469 13.87370378 11.47866185 7.38026689 13.77889773
10.53600207 12.70714639 5.1755306 14.16049317 5.32740025 16.36744632
9.57330081 13.17309885 7.97682816 7.13848392 9.37914495 17.08256052
16.58370831 10.8824037 17.00447988 15.34435017 8.92724986 8.04052215
11.45471907 8.49035479 16.31266802 4.91238327 8.81632288 10.30858435
17.18737034 12.41114677 12.05441828 7.84046494 9.16965441 6.63911287
4.44900393 14.27054657 5.17796142 9.56177954 8.06092304 14.23704375
6.36305329 10.55577213 7.29760177 15.36326758 13.84545457 7.51398571
6.18737077 7.59150763 7.45082811 3.06725252 3.23619985 17.97518396
13.8966688 7.87774061 13.40866891 11.16866861 8.94365227 14.04349463
5.70710662 2.90495664 11.06033867 6.47183888 15.72944232 14.77858947
3.01379814 7.26133446 11.78560762 4.47540678 14.19877081 6.04490669
6.52596256 15.16183336 8.60607703 8.10894149 10.60469598 12.4340502
8.19641355 7.86683848 9.52428868 10.63471422 5.84999185 11.13362284
12.03807586 14.20642003 13.30701635 9.06287704 7.5927058 6.29503364
6.83748614 14.1550358 3.89050322 6.86309022 11.34180739 7.81690429
8.653588 14.03381404 9.12136578 12.87247636 13.91368471 3.92660031
11.49918386 13.57872661 7.266045 13.23343461 4.51695465 10.88947142
9.6703235 12.30435404 5.79781909 7.51678521 5.11064867 8.74348701
7.52144318 15.42330473 13.08195222 17.69902463 3.95233236 13.22503427
12.1403755 10.2957699 14.69094853 11.76495674 8.82846793 11.47796772
16.90917638 7.20913035 15.59688176 10.61395437 12.19500175 7.26737684
14.38727016 4.90970646 17.38780676 13.8499805 10.01046353 10.44885638
12.28318046 11.43858545 9.31433305 9.13569732 13.15286545 15.67291545
7.059134 8.67174526 11.58980726 12.89134885 12.32677278 4.11751186
16.70218608 6.97725252 8.36580556 16.18702372 8.09229554 6.21669902
3.98406942 1.98196218 14.16913569 8.91632079 10.25114053 7.63008142
8.43121528 13.97176174 12.61543402 11.88917695 6.64408265 14.04694215
15.00204158 15.18389293 15.82521648 12.81690291 13.27980369 12.8437731
6.43473067 14.27035119 5.4526403 16.12822152 14.62087942 13.82040877
8.94835106 11.81700707 11.9986133 9.33416897 9.24511466 16.40475609
13.02131465 14.68450613 15.27679258 6.44360813 5.71296975 9.77059954
3.8105218 3.15587484 14.9119999 19.24217235 14.17740793 5.41669107
13.37314682 14.26892324 7.82425337 11.67318009 14.4771546 4.75718263
9.74268469 10.87269387 16.94751042 13.33386018 9.23221839 6.62370192
2.52853366 14.21920899 10.31141524 14.34137264 16.30216938 6.64873442
5.37188359 12.61177634 10.9604267 1.09988809 7.35931844 16.2175865
13.6684414 8.36041824 5.7772029 8.93803776 1.67842076 15.4932099
15.39896154 12.24330309 9.0038111 14.21214397 10.07681641 8.60155259
8.58035673 6.86130302 9.1453231 6.71812117 8.4273126 15.34094932
10.99750169 13.46709989 12.52967221 4.54192943 13.99820184 6.14951114
5.37476755 7.5241381 8.0100273 7.62400187 2.23269053 10.18443075
10.32586205 12.11134023 11.60334961 10.40338443 4.8377102 11.24398134
1.80999454 6.86553612 8.58169808 8.92831336 5.68473355 6.53590319
10.86454738 10.22698238 6.58416796 11.59687331 14.34580984 6.91268187
16.45453155 3.1938533 11.08694839 10.11585088 11.45350377 5.93132684
10.12523582 12.50005359 8.05235908 11.40820445 7.69528124 10.99800547
7.94707886 14.46864614 11.37516014 9.75999174 8.48222602 6.47793732
15.49080225 19.85328932 10.84161333 11.53964909 15.93937952 8.02325307
7.46809988 8.65408369 11.52905728 7.51484809 15.06917018 5.50553694
5.45241897 16.39477777 8.11388261 14.09355066 12.18547271 10.01100029
8.54996002 14.22018734 9.56864554 8.61050776 5.06333822 15.69041046
14.20310431 14.24074781 7.15211906 5.58081342 9.16546952 6.33297111
8.8284045 9.80980043 7.7805263 9.19089768 16.0185546 2.0615536
10.0797419 0.75048087 8.45990675 11.70503945 11.99904371 7.74981323
5.36557825 14.46811681 10.0957485 15.98480248 17.77061952 12.54082967
14.65196444 7.90290444 12.74817869 6.32518788 13.6284619 11.10327571
8.55884973 10.9534031 6.39979855 3.61559696 14.06951056 8.39888475
12.29549881 8.59204721 5.14732251 5.66267484 14.42632501 11.09256662
3.58985296 13.83743564 12.92581368 15.42819232 11.80182034 6.50789009
5.86769971 5.31620648 6.10747241 8.47060863 10.81127939 13.42070144
16.06476087 12.62894328 5.23980875 13.56023988 14.72669585 9.22601056
13.02814437 16.32149929 10.08875832 3.55205914 13.44140891 13.81175819
5.56532196 8.82876192]
import numpy as np
from sklearn. linear_model import LinearRegression
import matplotlib. pyplot as plt
% matplotlib inline
lr = LinearRegression( fit_intercept= True )
lr. fit( x, y)
print ( "估计的参数值为:%s" % ( lr. coef_) )
print ( 'R2:%s' % ( lr. score( x, y) ) )
x_test = np. array( [ 2 , 4 , 5 ] ) . reshape( 1 , - 1 )
y_hat = lr. predict( x_test)
print ( "预测值为: %s" % ( y_hat) )
估计的参数值为:[ 4.2 5.7 10.8]
R2:1.0
预测值为: [85.2]
class LR_LS ( ) :
def __init__ ( self) :
self. w = None
def fit ( self, X, y) :
self. w = np. linalg. inv( X. T. dot( X) ) . dot( X. T) . dot( y)
def predict ( self, X) :
y_pred = X. dot( self. w)
return y_pred
if __name__ == "__main__" :
lr_ls = LR_LS( )
lr_ls. fit( x, y)
print ( "估计的参数值:%s" % ( lr_ls. w) )
x_test = np. array( [ 2 , 4 , 5 ] ) . reshape( 1 , - 1 )
print ( "预测值为: %s" % ( lr_ls. predict( x_test) ) )
估计的参数值:[ 4.2 5.7 10.8]
预测值为: [85.2]
class LR_GD ( ) :
def __init__ ( self) :
self. w = None
def fit ( self, X, y, alpha= 0.02 , loss = 1e - 10 ) :
y = y. reshape( - 1 , 1 )
[ m, d] = np. shape( X)
self. w = np. zeros( ( d) )
tol = 1e5
while tol > loss:
h_f = X. dot( self. w) . reshape( - 1 , 1 )
theta = self. w + alpha* np. mean( X* ( y - h_f) , axis= 0 )
tol = np. sum ( np. abs ( theta - self. w) )
self. w = theta
def predict ( self, X) :
y_pred = X. dot( self. w)
return y_pred
if __name__ == "__main__" :
lr_gd = LR_GD( )
lr_gd. fit( x, y)
print ( "估计的参数值为:%s" % ( lr_gd. w) )
x_test = np. array( [ 2 , 4 , 5 ] ) . reshape( 1 , - 1 )
print ( "预测值为:%s" % ( lr_gd. predict( x_test) ) )
估计的参数值为:[ 4.20000001 5.70000003 10.79999997]
预测值为:[85.19999995]