1. 最小二乘原理
1.1 传统数学代码实现最小二乘
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import sklearn
from sklearn import datasets
import pandas as pd
# 导入糖尿病人数据,赋给变量d。
d = datasets.load_diabetes()
# 取数据集的所有行和第三列数据作为训练数据。
X = d.data[:,2]
# 将target数据赋给变量y。
Y = d.target
# 绘制散点图。
plt.scatter(X,Y)
plt.xlabel('X')
plt.ylabel('Y')
plt.show()
以下部分主要是实现最小二乘拟合线性模型 $y = ax + b$,求出参数 $a$、$b$ 的方法。
# Number of samples.
N = X.shape[0]
# Accumulate the sums that appear in the normal equations.
S_x2 = np.sum(X * X)
S_x = np.sum(X)
S_xy = np.sum(X * Y)
S_y = np.sum(Y)
# Normal equations A @ [a, b] = B for the line y = a*x + b.
A = np.array([[S_x2, S_x],
              [S_x, N]])
B = np.array([S_xy, S_y])
# Solve the 2x2 linear system for the slope and intercept.
jie = np.linalg.solve(A, B)
# Report the fitted parameters.
print('a=%f,b=%f' % (jie[0], jie[1]))
# Endpoints of the fitted line over the data range.
x_min = np.min(X)
x_max = np.max(X)
y_min = jie[0] * x_min + jie[1]
y_max = jie[0] * x_max + jie[1]
# Plot the data and the least-squares line.
# Fixed label typo 'orginal' -> 'original' (consistent with the later plots).
plt.scatter(X, Y, label='original data')
plt.plot([x_min, x_max], [y_min, y_max], 'r', label='model')
plt.legend()
plt.show()
a=949.435260,b=152.133484
使用梯度下降法(GD):
对于监督学习模型,需要对原始模型构建损失函数,然后基于优化算法对损失函数进行优化,以便寻找到最优的参数。
对于某一个损失函数:
$$L=\sum_{i=1}^{N}(y_i-ax_i-b)^2$$
更新策略:
$$\theta^{1} = \theta^{0} - \alpha \nabla L(\theta)$$
其中 $\theta$ 代表了模型中的参数,例如 $a$、$b$。
关于更新策略,下降梯度即是损失函数的一阶导数。
# Number of gradient-descent sweeps over the whole data set.
n_cishu = 400
# Initial parameter guesses.
a, b = 1, 1
# Learning rate (step size).
lv = 0.01
for i in range(n_cishu):
    # Stochastic update: adjust a and b after each individual sample.
    # NOTE: b's update intentionally uses the just-updated a, matching
    # the printed trace in the article.
    for j in range(N):
        a = a + lv * 2 * (Y[j] - a * X[j] - b) * X[j]
        b = b + lv * 2 * (Y[j] - a * X[j] - b)
    # Recompute the total squared-error loss after this sweep.
    L = 0
    for j in range(N):
        L = L + (Y[j] - a * X[j] - b) ** 2
    print('cishu %4d: loss = %f, a = %f, b= %f' % (i, L, a, b))
# Draw the final fitted line over the data range.
x_min = np.min(X)
x_max = np.max(X)
y_min = a * x_min + b
y_max = a * x_max + b
plt.scatter(X, Y, label='original data')
plt.plot([x_min, x_max], [y_min, y_max], 'r', label='model')
plt.legend()
plt.show()
cishu 0: loss = 2590736.867664, a = 18.689017, b= 148.815329
cishu 1: loss = 2557255.189453, a = 36.831348, b= 148.828655
cishu 2: loss = 2525130.800853, a = 54.614824, b= 148.822534
cishu 3: loss = 2494255.916234, a = 72.046438, b= 148.816532
cishu 4: loss = 2464581.753419, a = 89.133153, b= 148.810649
cishu 5: loss = 2436061.445196, a = 105.881792, b= 148.804882
cishu 6: loss = 2408649.957111, a = 122.299046, b= 148.799229
cishu 7: loss = 2382304.015732, a = 138.391470, b= 148.793687
cishu 8: loss = 2356982.039720, a = 154.165492, b= 148.788256
cishu 9: loss = 2332644.073594, a = 169.627411, b= 148.782932
cishu 10: loss = 2309251.724102, a = 184.783402, b= 148.777713
cishu 11: loss = 2286768.099073, a = 199.639520, b= 148.772598
cishu 12: loss = 2265157.748666, a = 214.201696, b= 148.767584
cishu 13: loss = 2244386.608923, a = 228.475748, b= 148.762669
cishu 14: loss = 2224421.947529, a = 242.467375, b= 148.757851
cishu 15: loss = 2205232.311697, a = 256.182166, b= 148.753129
cishu 16: loss = 2186787.478095, a = 269.625597, b= 148.748500
cishu 17: loss = 2169058.404731, a = 282.803039, b= 148.743962
cishu 18: loss = 2152017.184726, a = 295.719754, b= 148.739515
cishu 19: loss = 2135637.001889, a = 308.380901, b= 148.735155
cishu 20: loss = 2119892.088045, a = 320.791536, b= 148.730882
cishu 21: loss = 2104757.682020, a = 332.956616, b= 148.726693
cishu 22: loss = 2090209.990240, a = 344.881000, b= 148.722587
cishu 23: loss = 2076226.148871, a = 356.569449, b= 148.718562
cishu 24: loss = 2062784.187440, a = 368.026633, b= 148.714617
cishu 25: loss = 2049862.993883, a = 379.257126, b= 148.710750
cishu 26: loss = 2037442.280956, a = 390.265414, b= 148.706960
cishu 27: loss = 2025502.553972, a = 401.055894, b= 148.703244
cishu 28: loss = 2014025.079785, a = 411.632875, b= 148.699602
cishu 29: loss = 2002991.857005, a = 422.000581, b= 148.696032
cishu 30: loss = 1992385.587369, a = 432.163154, b= 148.692533
cishu 31: loss = 1982189.648231, a = 442.124651, b= 148.689103
cishu 32: loss = 1972388.066142, a = 451.889051, b= 148.685741
cishu 33: loss = 1962965.491447, a = 461.460255, b= 148.682445
cishu 34: loss = 1953907.173892, a = 470.842084, b= 148.679215
cishu 35: loss = 1945198.939172, a = 480.038285, b= 148.676048
cishu 36: loss = 1936827.166411, a = 489.052532, b= 148.672944
cishu 37: loss = 1928778.766513, a = 497.888424, b= 148.669902
cishu 38: loss = 1921041.161364, a = 506.549490, b= 148.666919
cishu 39: loss = 1913602.263848, a = 515.039190, b= 148.663996
cishu 40: loss = 1906450.458648, a = 523.360914, b= 148.661131
cishu 41: loss = 1899574.583794, a = 531.517985, b= 148.658322
cishu 42: loss = 1892963.912935, a = 539.513662, b= 148.655569
cishu 43: loss = 1886608.138306, a = 547.351138, b= 148.652870
cishu 44: loss = 1880497.354362, a = 555.033542, b= 148.650225
cishu 45: loss = 1874622.042047, a = 562.563943, b= 148.647632
cishu 46: loss = 1868973.053689, a = 569.945349, b= 148.645090
cishu 47: loss = 1863541.598476, a = 577.180708, b= 148.642599
cishu 48: loss = 1858319.228507, a = 584.272908, b= 148.640157
cishu 49: loss = 1853297.825385, a = 591.224784, b= 148.637763
cishu 50: loss = 1848469.587337, a = 598.039110, b= 148.635417
cishu 51: loss = 1843827.016840, a = 604.718609, b= 148.633117
cishu 52: loss = 1839362.908723, a = 611.265949, b= 148.630862
cishu 53: loss = 1835070.338746, a = 617.683744, b= 148.628653
cishu 54: loss = 1830942.652617, a = 623.974558, b= 148.626486
cishu 55: loss = 1826973.455442, a = 630.140902, b= 148.624363
cishu 56: loss = 1823156.601589, a = 636.185240, b= 148.622282
cishu 57: loss = 1819486.184948, a = 642.109985, b= 148.620242
cishu 58: loss = 1815956.529568, a = 647.917504, b= 148.618242
cishu 59: loss = 1812562.180670, a = 653.610117, b= 148.616282
cishu 60: loss = 1809297.895998, a = 659.190096, b= 148.614361
cishu 61: loss = 1806158.637522, a = 664.659671, b= 148.612477
cishu 62: loss = 1803139.563458, a = 670.021025, b= 148.610631
cishu 63: loss = 1800236.020598, a = 675.276301, b= 148.608822
cishu 64: loss = 1797443.536950, a = 680.427596, b= 148.607048
cishu 65: loss = 1794757.814654, a = 685.476969, b= 148.605309
cishu 66: loss = 1792174.723185, a = 690.426435, b= 148.603605
cishu 67: loss = 1789690.292815, a = 695.277972, b= 148.601934
cishu 68: loss = 1787300.708335, a = 700.033517, b= 148.600297
cishu 69: loss = 1785002.303019, a = 704.694970, b= 148.598692
cishu 70: loss = 1782791.552830, a = 709.264191, b= 148.597119
cishu 71: loss = 1780665.070844, a = 713.743007, b= 148.595576
cishu 72: loss = 1778619.601901, a = 718.133206, b= 148.594065
cishu 73: loss = 1776652.017457, a = 722.436540, b= 148.592583
cishu 74: loss = 1774759.310646, a = 726.654730, b= 148.591130
cishu 75: loss = 1772938.591527, a = 730.789459, b= 148.589707
cishu 76: loss = 1771187.082519, a = 734.842379, b= 148.588311
cishu 77: loss = 1769502.114022, a = 738.815108, b= 148.586943
cishu 78: loss = 1767881.120197, a = 742.709234, b= 148.585602
cishu 79: loss = 1766321.634920, a = 746.526311, b= 148.584288
cishu 80: loss = 1764821.287889, a = 750.267864, b= 148.583000
cishu 81: loss = 1763377.800890, a = 753.935387, b= 148.581737
cishu 82: loss = 1761988.984199, a = 757.530345, b= 148.580499
cishu 83: loss = 1760652.733134, a = 761.054173, b= 148.579286
cishu 84: loss = 1759367.024737, a = 764.508280, b= 148.578096
cishu 85: loss = 1758129.914584, a = 767.894044, b= 148.576931
cishu 86: loss = 1756939.533728, a = 771.212818, b= 148.575788
cishu 87: loss = 1755794.085750, a = 774.465927, b= 148.574668
cishu 88: loss = 1754691.843935, a = 777.654671, b= 148.573570
cishu 89: loss = 1753631.148552, a = 780.780322, b= 148.572493
cishu 90: loss = 1752610.404243, a = 783.844130, b= 148.571438
cishu 91: loss = 1751628.077518, a = 786.847318, b= 148.570404
cishu 92: loss = 1750682.694337, a = 789.791085, b= 148.569391
cishu 93: loss = 1749772.837797, a = 792.676607, b= 148.568397
cishu 94: loss = 1748897.145906, a = 795.505037, b= 148.567423
cishu 95: loss = 1748054.309442, a = 798.277504, b= 148.566469
cishu 96: loss = 1747243.069899, a = 800.995115, b= 148.565533
cishu 97: loss = 1746462.217508, a = 803.658956, b= 148.564616
cishu 98: loss = 1745710.589343, a = 806.270090, b= 148.563716
cishu 99: loss = 1744987.067493, a = 808.829561, b= 148.562835
cishu 100: loss = 1744290.577312, a = 811.338391, b= 148.561971
cishu 101: loss = 1743620.085732, a = 813.797581, b= 148.561125
cishu 102: loss = 1742974.599648, a = 816.208114, b= 148.560295
cishu 103: loss = 1742353.164357, a = 818.570953, b= 148.559481
cishu 104: loss = 1741754.862068, a = 820.887041, b= 148.558683
cishu 105: loss = 1741178.810465, a = 823.157303, b= 148.557902
cishu 106: loss = 1740624.161324, a = 825.382646, b= 148.557135
cishu 107: loss = 1740090.099189, a = 827.563959, b= 148.556384
cishu 108: loss = 1739575.840097, a = 829.702112, b= 148.555648
cishu 109: loss = 1739080.630352, a = 831.797961, b= 148.554926
cishu 110: loss = 1738603.745351, a = 833.852341, b= 148.554219
cishu 111: loss = 1738144.488447, a = 835.866073, b= 148.553526
cishu 112: loss = 1737702.189871, a = 837.839962, b= 148.552846
cishu 113: loss = 1737276.205677, a = 839.774796, b= 148.552180
cishu 114: loss = 1736865.916748, a = 841.671348, b= 148.551527
cishu 115: loss = 1736470.727823, a = 843.530375, b= 148.550887
cishu 116: loss = 1736090.066576, a = 845.352619, b= 148.550259
cishu 117: loss = 1735723.382723, a = 847.138809, b= 148.549644
cishu 118: loss = 1735370.147167, a = 848.889657, b= 148.549041
cishu 119: loss = 1735029.851171, a = 850.605863, b= 148.548450
cishu 120: loss = 1734702.005574, a = 852.288113, b= 148.547871
cishu 121: loss = 1734386.140028, a = 853.937078, b= 148.547303
cishu 122: loss = 1734081.802266, a = 855.553417, b= 148.546747
cishu 123: loss = 1733788.557403, a = 857.137775, b= 148.546201
cishu 124: loss = 1733505.987262, a = 858.690785, b= 148.545666
cishu 125: loss = 1733233.689723, a = 860.213068, b= 148.545142
cishu 126: loss = 1732971.278103, a = 861.705231, b= 148.544628
cishu 127: loss = 1732718.380559, a = 863.167870, b= 148.544125
cishu 128: loss = 1732474.639507, a = 864.601570, b= 148.543631
cishu 129: loss = 1732239.711074, a = 866.006902, b= 148.543147
cishu 130: loss = 1732013.264567, a = 867.384429, b= 148.542673
cishu 131: loss = 1731794.981959, a = 868.734701, b= 148.542208
cishu 132: loss = 1731584.557401, a = 870.058256, b= 148.541752
cishu 133: loss = 1731381.696752, a = 871.355623, b= 148.541306
cishu 134: loss = 1731186.117121, a = 872.627321, b= 148.540868
cishu 135: loss = 1730997.546436, a = 873.873858, b= 148.540438
cishu 136: loss = 1730815.723022, a = 875.095730, b= 148.540018
cishu 137: loss = 1730640.395202, a = 876.293427, b= 148.539605
cishu 138: loss = 1730471.320908, a = 877.467426, b= 148.539201
cishu 139: loss = 1730308.267311, a = 878.618197, b= 148.538805
cishu 140: loss = 1730151.010464, a = 879.746199, b= 148.538416
cishu 141: loss = 1729999.334955, a = 880.851882, b= 148.538036
cishu 142: loss = 1729853.033583, a = 881.935688, b= 148.537663
cishu 143: loss = 1729711.907037, a = 882.998050, b= 148.537297
cishu 144: loss = 1729575.763593, a = 884.039393, b= 148.536938
cishu 145: loss = 1729444.418820, a = 885.060132, b= 148.536587
cishu 146: loss = 1729317.695300, a = 886.060674, b= 148.536242
cishu 147: loss = 1729195.422355, a = 887.041420, b= 148.535904
cishu 148: loss = 1729077.435792, a = 888.002761, b= 148.535573
cishu 149: loss = 1728963.577645, a = 888.945081, b= 148.535249
cishu 150: loss = 1728853.695944, a = 889.868757, b= 148.534931
cishu 151: loss = 1728747.644477, a = 890.774156, b= 148.534619
cishu 152: loss = 1728645.282573, a = 891.661642, b= 148.534314
cishu 153: loss = 1728546.474885, a = 892.531568, b= 148.534014
cishu 154: loss = 1728451.091187, a = 893.384282, b= 148.533720
cishu 155: loss = 1728359.006178, a = 894.220124, b= 148.533433
cishu 156: loss = 1728270.099288, a = 895.039428, b= 148.533150
cishu 157: loss = 1728184.254503, a = 895.842521, b= 148.532874
cishu 158: loss = 1728101.360185, a = 896.629725, b= 148.532603
cishu 159: loss = 1728021.308903, a = 897.401353, b= 148.532337
cishu 160: loss = 1727943.997275, a = 898.157714, b= 148.532077
cishu 161: loss = 1727869.325813, a = 898.899110, b= 148.531821
cishu 162: loss = 1727797.198769, a = 899.625836, b= 148.531571
cishu 163: loss = 1727727.523995, a = 900.338184, b= 148.531326
cishu 164: loss = 1727660.212803, a = 901.036437, b= 148.531086
cishu 165: loss = 1727595.179837, a = 901.720875, b= 148.530850
cishu 166: loss = 1727532.342938, a = 902.391770, b= 148.530619
cishu 167: loss = 1727471.623027, a = 903.049391, b= 148.530392
cishu 168: loss = 1727412.943986, a = 903.694001, b= 148.530170
cishu 169: loss = 1727356.232544, a = 904.325856, b= 148.529953
cishu 170: loss = 1727301.418168, a = 904.945210, b= 148.529740
cishu 171: loss = 1727248.432959, a = 905.552309, b= 148.529531
cishu 172: loss = 1727197.211551, a = 906.147396, b= 148.529326
cishu 173: loss = 1727147.691014, a = 906.730709, b= 148.529125
cishu 174: loss = 1727099.810763, a = 907.302481, b= 148.528928
cishu 175: loss = 1727053.512464, a = 907.862939, b= 148.528735
cishu 176: loss = 1727008.739955, a = 908.412309, b= 148.528546
cishu 177: loss = 1726965.439156, a = 908.950808, b= 148.528360
cishu 178: loss = 1726923.557995, a = 909.478653, b= 148.528179
cishu 179: loss = 1726883.046328, a = 909.996055, b= 148.528000
cishu 180: loss = 1726843.855869, a = 910.503219, b= 148.527826
cishu 181: loss = 1726805.940116, a = 911.000348, b= 148.527655
cishu 182: loss = 1726769.254286, a = 911.487641, b= 148.527487
cishu 183: loss = 1726733.755249, a = 911.965292, b= 148.527322
cishu 184: loss = 1726699.401462, a = 912.433493, b= 148.527161
cishu 185: loss = 1726666.152915, a = 912.892430, b= 148.527003
cishu 186: loss = 1726633.971067, a = 913.342287, b= 148.526848
cishu 187: loss = 1726602.818794, a = 913.783243, b= 148.526696
cishu 188: loss = 1726572.660334, a = 914.215474, b= 148.526548
cishu 189: loss = 1726543.461235, a = 914.639153, b= 148.526402
cishu 190: loss = 1726515.188306, a = 915.054449, b= 148.526259
cishu 191: loss = 1726487.809572, a = 915.461529, b= 148.526119
cishu 192: loss = 1726461.294221, a = 915.860553, b= 148.525981
cishu 193: loss = 1726435.612569, a = 916.251683, b= 148.525846
cishu 194: loss = 1726410.736010, a = 916.635074, b= 148.525714
cishu 195: loss = 1726386.636980, a = 917.010879, b= 148.525585
cishu 196: loss = 1726363.288915, a = 917.379249, b= 148.525458
cishu 197: loss = 1726340.666217, a = 917.740330, b= 148.525334
cishu 198: loss = 1726318.744212, a = 918.094267, b= 148.525212
cishu 199: loss = 1726297.499121, a = 918.441201, b= 148.525093
cishu 200: loss = 1726276.908024, a = 918.781270, b= 148.524975
cishu 201: loss = 1726256.948827, a = 919.114611, b= 148.524861
cishu 202: loss = 1726237.600231, a = 919.441356, b= 148.524748
cishu 203: loss = 1726218.841703, a = 919.761637, b= 148.524638
cishu 204: loss = 1726200.653451, a = 920.075580, b= 148.524530
cishu 205: loss = 1726183.016388, a = 920.383312, b= 148.524424
cishu 206: loss = 1726165.912113, a = 920.684955, b= 148.524320
cishu 207: loss = 1726149.322881, a = 920.980630, b= 148.524218
cishu 208: loss = 1726133.231583, a = 921.270455, b= 148.524118
cishu 209: loss = 1726117.621717, a = 921.554545, b= 148.524021
cishu 210: loss = 1726102.477370, a = 921.833014, b= 148.523925
cishu 211: loss = 1726087.783192, a = 922.105974, b= 148.523831
cishu 212: loss = 1726073.524378, a = 922.373533, b= 148.523739
cishu 213: loss = 1726059.686650, a = 922.635798, b= 148.523648
cishu 214: loss = 1726046.256229, a = 922.892873, b= 148.523560
cishu 215: loss = 1726033.219826, a = 923.144863, b= 148.523473
cishu 216: loss = 1726020.564620, a = 923.391866, b= 148.523388
cishu 217: loss = 1726008.278237, a = 923.633982, b= 148.523305
cishu 218: loss = 1725996.348741, a = 923.871308, b= 148.523223
cishu 219: loss = 1725984.764612, a = 924.103938, b= 148.523143
cishu 220: loss = 1725973.514733, a = 924.331966, b= 148.523064
cishu 221: loss = 1725962.588374, a = 924.555481, b= 148.522987
cishu 222: loss = 1725951.975181, a = 924.774575, b= 148.522912
cishu 223: loss = 1725941.665156, a = 924.989333, b= 148.522838
cishu 224: loss = 1725931.648652, a = 925.199842, b= 148.522765
cishu 225: loss = 1725921.916352, a = 925.406186, b= 148.522694
cishu 226: loss = 1725912.459264, a = 925.608447, b= 148.522625
cishu 227: loss = 1725903.268703, a = 925.806706, b= 148.522556
cishu 228: loss = 1725894.336285, a = 926.001043, b= 148.522489
cishu 229: loss = 1725885.653913, a = 926.191534, b= 148.522424
cishu 230: loss = 1725877.213768, a = 926.378257, b= 148.522360
cishu 231: loss = 1725869.008296, a = 926.561285, b= 148.522297
cishu 232: loss = 1725861.030202, a = 926.740692, b= 148.522235
cishu 233: loss = 1725853.272442, a = 926.916548, b= 148.522174
cishu 234: loss = 1725845.728205, a = 927.088926, b= 148.522115
cishu 235: loss = 1725838.390917, a = 927.257893, b= 148.522057
cishu 236: loss = 1725831.254223, a = 927.423516, b= 148.522000
cishu 237: loss = 1725824.311982, a = 927.585863, b= 148.521944
cishu 238: loss = 1725817.558260, a = 927.744997, b= 148.521889
cishu 239: loss = 1725810.987324, a = 927.900983, b= 148.521835
cishu 240: loss = 1725804.593630, a = 928.053883, b= 148.521783
cishu 241: loss = 1725798.371821, a = 928.203757, b= 148.521731
cishu 242: loss = 1725792.316718, a = 928.350666, b= 148.521680
cishu 243: loss = 1725786.423315, a = 928.494668, b= 148.521631
cishu 244: loss = 1725780.686770, a = 928.635821, b= 148.521582
cishu 245: loss = 1725775.102403, a = 928.774181, b= 148.521535
cishu 246: loss = 1725769.665688, a = 928.909804, b= 148.521488
cishu 247: loss = 1725764.372248, a = 929.042743, b= 148.521442
cishu 248: loss = 1725759.217848, a = 929.173052, b= 148.521397
cishu 249: loss = 1725754.198393, a = 929.300783, b= 148.521353
cishu 250: loss = 1725749.309922, a = 929.425986, b= 148.521310
cishu 251: loss = 1725744.548603, a = 929.548712, b= 148.521268
cishu 252: loss = 1725739.910728, a = 929.669010, b= 148.521226
cishu 253: loss = 1725735.392708, a = 929.786928, b= 148.521186
cishu 254: loss = 1725730.991072, a = 929.902512, b= 148.521146
cishu 255: loss = 1725726.702459, a = 930.015810, b= 148.521107
cishu 256: loss = 1725722.523617, a = 930.126866, b= 148.521069
cishu 257: loss = 1725718.451397, a = 930.235724, b= 148.521031
cishu 258: loss = 1725714.482753, a = 930.342429, b= 148.520995
cishu 259: loss = 1725710.614734, a = 930.447022, b= 148.520959
cishu 260: loss = 1725706.844482, a = 930.549546, b= 148.520923
cishu 261: loss = 1725703.169232, a = 930.650042, b= 148.520889
cishu 262: loss = 1725699.586303, a = 930.748549, b= 148.520855
cishu 263: loss = 1725696.093102, a = 930.845107, b= 148.520822
cishu 264: loss = 1725692.687115, a = 930.939754, b= 148.520789
cishu 265: loss = 1725689.365908, a = 931.032529, b= 148.520757
cishu 266: loss = 1725686.127121, a = 931.123468, b= 148.520726
cishu 267: loss = 1725682.968469, a = 931.212608, b= 148.520695
cishu 268: loss = 1725679.887739, a = 931.299985, b= 148.520665
cishu 269: loss = 1725676.882782, a = 931.385632, b= 148.520635
cishu 270: loss = 1725673.951521, a = 931.469585, b= 148.520606
cishu 271: loss = 1725671.091938, a = 931.551877, b= 148.520578
cishu 272: loss = 1725668.302080, a = 931.632540, b= 148.520550
cishu 273: loss = 1725665.580052, a = 931.711608, b= 148.520523
cishu 274: loss = 1725662.924017, a = 931.789111, b= 148.520496
cishu 275: loss = 1725660.332194, a = 931.865080, b= 148.520470
cishu 276: loss = 1725657.802855, a = 931.939547, b= 148.520445
cishu 277: loss = 1725655.334325, a = 932.012540, b= 148.520420
cishu 278: loss = 1725652.924980, a = 932.084089, b= 148.520395
cishu 279: loss = 1725650.573243, a = 932.154222, b= 148.520371
cishu 280: loss = 1725648.277584, a = 932.222968, b= 148.520347
cishu 281: loss = 1725646.036521, a = 932.290353, b= 148.520324
cishu 282: loss = 1725643.848614, a = 932.356405, b= 148.520301
cishu 283: loss = 1725641.712465, a = 932.421150, b= 148.520279
cishu 284: loss = 1725639.626719, a = 932.484614, b= 148.520257
cishu 285: loss = 1725637.590060, a = 932.546823, b= 148.520236
cishu 286: loss = 1725635.601209, a = 932.607801, b= 148.520215
cishu 287: loss = 1725633.658927, a = 932.667572, b= 148.520194
cishu 288: loss = 1725631.762010, a = 932.726160, b= 148.520174
cishu 289: loss = 1725629.909289, a = 932.783590, b= 148.520154
cishu 290: loss = 1725628.099627, a = 932.839883, b= 148.520135
cishu 291: loss = 1725626.331922, a = 932.895062, b= 148.520116
cishu 292: loss = 1725624.605102, a = 932.949149, b= 148.520097
cishu 293: loss = 1725622.918128, a = 933.002166, b= 148.520079
cishu 294: loss = 1725621.269989, a = 933.054135, b= 148.520061
cishu 295: loss = 1725619.659702, a = 933.105075, b= 148.520043
cishu 296: loss = 1725618.086312, a = 933.155007, b= 148.520026
cishu 297: loss = 1725616.548894, a = 933.203951, b= 148.520009
cishu 298: loss = 1725615.046544, a = 933.251927, b= 148.519993
cishu 299: loss = 1725613.578388, a = 933.298953, b= 148.519977
cishu 300: loss = 1725612.143573, a = 933.345050, b= 148.519961
cishu 301: loss = 1725610.741272, a = 933.390234, b= 148.519945
cishu 302: loss = 1725609.370679, a = 933.434524, b= 148.519930
cishu 303: loss = 1725608.031013, a = 933.477937, b= 148.519915
cishu 304: loss = 1725606.721512, a = 933.520492, b= 148.519900
cishu 305: loss = 1725605.441435, a = 933.562205, b= 148.519886
cishu 306: loss = 1725604.190064, a = 933.603092, b= 148.519872
cishu 307: loss = 1725602.966697, a = 933.643171, b= 148.519858
cishu 308: loss = 1725601.770653, a = 933.682456, b= 148.519845
cishu 309: loss = 1725600.601270, a = 933.720964, b= 148.519831
cishu 310: loss = 1725599.457903, a = 933.758711, b= 148.519818
cishu 311: loss = 1725598.339923, a = 933.795710, b= 148.519806
cishu 312: loss = 1725597.246722, a = 933.831977, b= 148.519793
cishu 313: loss = 1725596.177703, a = 933.867527, b= 148.519781
cishu 314: loss = 1725595.132289, a = 933.902373, b= 148.519769
cishu 315: loss = 1725594.109917, a = 933.936530, b= 148.519757
cishu 316: loss = 1725593.110038, a = 933.970011, b= 148.519746
cishu 317: loss = 1725592.132118, a = 934.002830, b= 148.519734
cishu 318: loss = 1725591.175638, a = 934.034999, b= 148.519723
cishu 319: loss = 1725590.240092, a = 934.066532, b= 148.519712
cishu 320: loss = 1725589.324986, a = 934.097441, b= 148.519702
cishu 321: loss = 1725588.429841, a = 934.127738, b= 148.519691
cishu 322: loss = 1725587.554189, a = 934.157436, b= 148.519681
cishu 323: loss = 1725586.697574, a = 934.186546, b= 148.519671
cishu 324: loss = 1725585.859554, a = 934.215081, b= 148.519661
cishu 325: loss = 1725585.039694, a = 934.243051, b= 148.519651
cishu 326: loss = 1725584.237575, a = 934.270467, b= 148.519642
cishu 327: loss = 1725583.452786, a = 934.297341, b= 148.519633
cishu 328: loss = 1725582.684926, a = 934.323683, b= 148.519624
cishu 329: loss = 1725581.933606, a = 934.349504, b= 148.519615
cishu 330: loss = 1725581.198445, a = 934.374814, b= 148.519606
cishu 331: loss = 1725580.479074, a = 934.399623, b= 148.519598
cishu 332: loss = 1725579.775131, a = 934.423941, b= 148.519589
cishu 333: loss = 1725579.086264, a = 934.447779, b= 148.519581
cishu 334: loss = 1725578.412130, a = 934.471144, b= 148.519573
cishu 335: loss = 1725577.752394, a = 934.494048, b= 148.519565
cishu 336: loss = 1725577.106730, a = 934.516498, b= 148.519557
cishu 337: loss = 1725576.474819, a = 934.538504, b= 148.519550
cishu 338: loss = 1725575.856351, a = 934.560074, b= 148.519542
cishu 339: loss = 1725575.251023, a = 934.581218, b= 148.519535
cishu 340: loss = 1725574.658540, a = 934.601943, b= 148.519528
cishu 341: loss = 1725574.078613, a = 934.622259, b= 148.519521
cishu 342: loss = 1725573.510962, a = 934.642172, b= 148.519514
cishu 343: loss = 1725572.955312, a = 934.661691, b= 148.519507
cishu 344: loss = 1725572.411396, a = 934.680825, b= 148.519501
cishu 345: loss = 1725571.878952, a = 934.699579, b= 148.519494
cishu 346: loss = 1725571.357726, a = 934.717963, b= 148.519488
cishu 347: loss = 1725570.847468, a = 934.735982, b= 148.519482
cishu 348: loss = 1725570.347937, a = 934.753646, b= 148.519476
cishu 349: loss = 1725569.858895, a = 934.770959, b= 148.519470
cishu 350: loss = 1725569.380112, a = 934.787931, b= 148.519464
cishu 351: loss = 1725568.911360, a = 934.804566, b= 148.519458
cishu 352: loss = 1725568.452420, a = 934.820872, b= 148.519453
cishu 353: loss = 1725568.003077, a = 934.836856, b= 148.519447
cishu 354: loss = 1725567.563120, a = 934.852523, b= 148.519442
cishu 355: loss = 1725567.132345, a = 934.867881, b= 148.519436
cishu 356: loss = 1725566.710551, a = 934.882934, b= 148.519431
cishu 357: loss = 1725566.297542, a = 934.897690, b= 148.519426
cishu 358: loss = 1725565.893128, a = 934.912154, b= 148.519421
cishu 359: loss = 1725565.497121, a = 934.926331, b= 148.519416
cishu 360: loss = 1725565.109340, a = 934.940228, b= 148.519411
cishu 361: loss = 1725564.729607, a = 934.953850, b= 148.519407
cishu 362: loss = 1725564.357747, a = 934.967203, b= 148.519402
cishu 363: loss = 1725563.993590, a = 934.980291, b= 148.519398
cishu 364: loss = 1725563.636972, a = 934.993120, b= 148.519393
cishu 365: loss = 1725563.287729, a = 935.005696, b= 148.519389
cishu 366: loss = 1725562.945702, a = 935.018023, b= 148.519385
cishu 367: loss = 1725562.610739, a = 935.030106, b= 148.519380
cishu 368: loss = 1725562.282686, a = 935.041949, b= 148.519376
cishu 369: loss = 1725561.961396, a = 935.053559, b= 148.519372
cishu 370: loss = 1725561.646724, a = 935.064938, b= 148.519368
cishu 371: loss = 1725561.338531, a = 935.076093, b= 148.519365
cishu 372: loss = 1725561.036676, a = 935.087027, b= 148.519361
cishu 373: loss = 1725560.741026, a = 935.097744, b= 148.519357
cishu 374: loss = 1725560.451449, a = 935.108250, b= 148.519354
cishu 375: loss = 1725560.167816, a = 935.118547, b= 148.519350
cishu 376: loss = 1725559.890000, a = 935.128641, b= 148.519347
cishu 377: loss = 1725559.617879, a = 935.138535, b= 148.519343
cishu 378: loss = 1725559.351332, a = 935.148234, b= 148.519340
cishu 379: loss = 1725559.090241, a = 935.157740, b= 148.519337
cishu 380: loss = 1725558.834492, a = 935.167059, b= 148.519333
cishu 381: loss = 1725558.583972, a = 935.176193, b= 148.519330
cishu 382: loss = 1725558.338570, a = 935.185146, b= 148.519327
cishu 383: loss = 1725558.098179, a = 935.193922, b= 148.519324
cishu 384: loss = 1725557.862695, a = 935.202525, b= 148.519321
cishu 385: loss = 1725557.632013, a = 935.210957, b= 148.519318
cishu 386: loss = 1725557.406033, a = 935.219223, b= 148.519315
cishu 387: loss = 1725557.184657, a = 935.227324, b= 148.519313
cishu 388: loss = 1725556.967789, a = 935.235266, b= 148.519310
cishu 389: loss = 1725556.755334, a = 935.243051, b= 148.519307
cishu 390: loss = 1725556.547200, a = 935.250681, b= 148.519305
cishu 391: loss = 1725556.343298, a = 935.258160, b= 148.519302
cishu 392: loss = 1725556.143538, a = 935.265492, b= 148.519299
cishu 393: loss = 1725555.947836, a = 935.272678, b= 148.519297
cishu 394: loss = 1725555.756105, a = 935.279723, b= 148.519295
cishu 395: loss = 1725555.568265, a = 935.286628, b= 148.519292
cishu 396: loss = 1725555.384233, a = 935.293396, b= 148.519290
cishu 397: loss = 1725555.203932, a = 935.300030, b= 148.519288
cishu 398: loss = 1725555.027283, a = 935.306533, b= 148.519285
cishu 399: loss = 1725554.854212, a = 935.312908, b= 148.519283
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-NACdSawU-1622471030273)(output_6_1.png)]
1.2 动态可视化回归过程
FuncAnimation(fig,func,frames,init_func,interval,blit)
- fig:绘制动图的画布名称。
- func:自定义动画函数。
- frames:动画长度,一次循环包含的帧数,在函数运行时,会传递给函数func(n)的形参n。
- init_func:自定义开始帧,即传入刚定义的函数init,初始化函数。
- interval:更新频率,以ms计。
- blit:选择更新所有点,还是仅更新变化的点,应选择True。
%matplotlib nbagg
import matplotlib.pyplot as plt
import matplotlib.animation as animation
n_cishu = 3000
a,b = 1,1
xuexi = 0.001
fig = plt.figure()
imgs = []
for i in range(n_cishu):
for j in range(N):
a = a + xuexi*2*(Y[j] - a*X[j]-b)*X[j]
b = b + xuexi*2*(Y[j] - a*X[j]-b)
L = 0
for j in range(N):
L = L + (Y[j] - a*X[j]-b)**2
if i % 50 == 0:
x_min = np.min(X)
x_max = np.max(X)
y_min = a*x_min+b
y_max = a*x_max+b
img = plt.scatter(X,Y,label='original data')
img = plt.plot([x_min,x_max],[y_min,y_max],'r',label='model')
imgs.append(img)
ani = animation.ArtistAnimation(fig,imgs)
plt.show()
<IPython.core.display.Javascript object>
![](https://img-blog.csdnimg.cn/2022010710465843395.png)
1.3 使用sklearn进行线性回归
线性回归模型:
lr = sklearn.linear_model.LinearRegression(fit_intercept=True, normalize=False, copy_X=True, n_jobs=1)
- fit_intercept:默认为True,是否计算模型的截距,否则数据中心化处理。
- normalize:默认为False,是否中心化处理,或者使用sklearn.preprocessing.StandardScaler()
- copy_X:默认为True,否则X会被改写。
- n_jobs:默认为1,使用cpu的个数。
调用方法:
- coef_:训练后的输入端模型系数。
- intercept_:截距
- predict(x):预测数据
- score:评估
# Solve the same linear-regression problem with sklearn's estimator API.
from sklearn import linear_model
from sklearn import datasets
import numpy as np
d = datasets.load_diabetes()
# np.newaxis keeps X two-dimensional (n_samples, 1), as sklearn requires.
X = d.data[:, np.newaxis, 2]
Y = d.target
print('X:', X.shape)
print('Y:', Y.shape)
# Fit an ordinary least-squares model; see parameter notes above.
regr = linear_model.LinearRegression()
regr.fit(X, Y)
a, b = regr.coef_, regr.intercept_
print('a= %f,b= %f' % (a, b))
# Draw the fitted line across the data range.
x_min, x_max = np.min(X), np.max(X)
plt.scatter(X, Y)
plt.plot([x_min, x_max], [a * x_min + b, a * x_max + b], 'r')
plt.show()
X: (442, 1)
Y: (442,)
a= 949.435260,b= 152.133484
<IPython.core.display.Javascript object>
![](https://img-blog.csdnimg.cn/2022010710465825103.png)
2. 探究线性回归拟合多项式函数原理
使用sklearn拟合多项式函数
PolynomialFeatures进行特征的构造(特征与特征相乘),该函数有三个参数:
- degree:控制多项式的次数;
- interaction_only:默认为False,如果指定为True,那么就不会有特征自己和自己的结合项,组合特征中没有 $a^2$ 和 $b^2$;
- include_bias:默认为True,如果为True的话,结果中就会有0次幂项,即全为1这一列。
Pipeline函数(流水线),把一系列的类连成一条流水线,然后让数据在流水线上“跑起来”。
class sklearn.pipeline.Pipeline(steps,memory=None,verbose=False):
- 设定流水线上一道道工序,并给一道道工序起一个名字。[(),()]类型,列表里面是一个个元组,分别为名字和工序,实现fit/transform的列表,最后一个对象是评估器。
- memory,类型:None、str等,用于缓存管道的被训练的转换器。默认情况下不执行缓存。可以用字符串指定缓存的路径。
- verbose:默认为False,用来显示每个流水线所消耗的时间,默认不显示。
调用的参数:
- named_steps:只读属性,可根据用户给定的名字访问任何步骤,键是步骤名,值是对应的步骤对象(转换器或评估器)。
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline
# Sample points of a known quadratic y = -20*t^2 + 90*t + 800.
t = np.array([2, 4, 6, 8])
pa, pb, pc = -20, 90, 800
y = pa * t ** 2 + pb * t + pc
# Pipeline: expand t into [1, t, t^2], then fit a linear model with
# fit_intercept=False — the bias column from PolynomialFeatures already
# plays the role of the intercept.
model = Pipeline([('poly', PolynomialFeatures(degree=2)),
                  ('linear', LinearRegression(fit_intercept=False))])
model = model.fit(t[:, np.newaxis], y)
model.named_steps['linear'].coef_
array([800., 90., -20.])
Pipeline将算法连接起来。
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeatures
from sklearn.pipeline import Pipeline
#从-3,3中随机取值
x = np.random.uniform(-3,3,size = 100)
#转换成一列
X1 = x.reshape(-1,1)
x
array([ 0.64084587, -0.11502597, -1.24388168, 0.36996037, -1.99241058,
-1.15287568, 0.95485448, 0.43372833, 0.12157036, -0.94437214,
1.47413249, 2.04227617, -1.36380038, -1.26772816, -0.55872249,
-1.21782647, 1.74740853, 1.84385472, -0.98296742, 2.97871352,
2.72478343, -1.33688328, -0.8511405 , 0.98721063, -2.6831202 ,
-2.40013065, -1.3489865 , -1.73266456, 1.60959222, 0.21356265,
0.3305497 , 1.65001695, -1.31356888, -2.00218193, -0.31918498,
0.29642117, 2.65375127, 0.9689322 , 1.52712763, 2.80490711,
-2.51476141, 1.09974238, -2.59777836, 1.61336388, 2.92498289,
1.49494967, -1.44983385, -0.36288847, -0.23628546, 2.24890589,
-2.07196042, 0.59035022, 0.8199342 , -0.80955785, 0.18998966,
-1.46055554, -0.92640668, 0.46121942, 0.93280307, -1.5079657 ,
-0.97797343, 0.53090615, 1.8992068 , 1.09406365, 1.14033716,
2.46727101, -2.62274806, 2.78741218, -2.62243954, 0.14901036,
1.34764166, -1.85746619, 1.50709718, 2.2149727 , -0.44419329,
0.44752385, -0.16081948, 0.14776611, -0.49505029, -1.090961 ,
-2.60175714, 0.41914586, -0.58771057, 2.20514719, 1.30988586,
-1.09294743, -0.52928834, 1.6574256 , 1.43164976, 1.77197508,
-1.89001765, 0.97692067, -0.86728969, 0.9114925 , 1.73473852,
-0.78685291, -1.94921809, 2.30754732, 2.94786966, 1.37511611])
X1[0:5,:]
array([[ 0.64084587],
[-0.11502597],
[-1.24388168],
[ 0.36996037],
[-1.99241058]])
# Noisy quadratic target: y = 0.5*x^2 + x + 2 plus Gaussian noise (sigma=2).
y1 = 0.5*x**2 + x +2 + np.random.normal(0,2,size=100)
plt.scatter(x,y1)
<matplotlib.collections.PathCollection at 0x2b1aa13c1d0>
y1
array([ 0.96818306, 2.23883725, 1.92468482, -2.71637306, 4.45307092,
0.01816149, 1.83609337, 2.84743306, 1.6606926 , 3.32939878,
5.13270614, 6.94994308, 5.26852484, 0.776832 , -3.68122139,
-0.5092344 , 1.73890044, 8.53156829, 4.27045319, 10.39202826,
11.0183172 , -1.88915938, -1.60786658, 1.98119282, 3.45607561,
5.52254344, 2.62589771, -0.964355 , 6.13707036, 1.13207466,
1.00973905, 4.38597909, 0.8282406 , 0.31836746, 2.52730438,
6.13836595, 7.78835002, 6.40451125, 6.97002402, 7.66174213,
1.70348611, 5.42996596, 1.67880113, 1.21632926, 9.18677169,
5.29843559, -0.67039048, 3.26823983, 1.46620622, 3.56803585,
-2.83134549, 5.75070392, 1.0478809 , 0.30132542, 1.9288644 ,
4.98144168, 3.62208072, -0.06874818, 4.21452823, 4.88247452,
3.85579813, 1.16749376, 7.62196597, 4.08134595, 0.73896439,
7.45182763, 4.77644065, 8.85020327, 4.1570157 , 1.23410209,
9.68626177, -1.28932949, 5.22198367, 9.94091774, 1.9535322 ,
1.64324317, 0.28503905, 2.34850118, 2.47608781, 1.7209606 ,
0.26932514, 2.81809854, 1.31506573, 5.54638903, 0.90487236,
3.36809024, 0.43599388, 1.11588937, 2.35450458, 10.34477969,
1.64070228, 4.89975885, 0.0448951 , 2.01677536, 6.01236522,
1.80500177, 2.37439515, 3.61958116, 8.90862482, 0.606599 ])
# A plain linear regression first: a straight line underfits quadratic data.
lr = LinearRegression()
# Train on the raw single-feature matrix.
lr.fit(X1, y1)
# Predict on the training inputs to visualize the fit.
y_predict = lr.predict(X1)
plt.scatter(x, y1)
# Sort by x so the fitted curve is drawn left to right.
plt.plot(np.sort(x), y_predict[np.argsort(x)])
[<matplotlib.lines.Line2D at 0x2b1ab684860>]
poly = PolynomialFeatures(degree=2)
# Add a quadratic feature so a linear model can fit a one-variable quadratic:
# each row [a] becomes [1, a, a^2].
poly.fit(X1)
X2 = poly.transform(X1)
X2[0:5,:]
array([[ 1. , 0.64084587, 0.41068343],
[ 1. , -0.11502597, 0.01323097],
[ 1. , -1.24388168, 1.54724163],
[ 1. , 0.36996037, 0.13687068],
[ 1. , -1.99241058, 3.96969993]])
# The feature rows above have the form [1, a, a^2]; verify the last entry.
0.64084587**2
0.41068342909605693
X1[0:5,:]
array([[ 0.64084587],
[-0.11502597],
[-1.24388168],
[ 0.36996037],
[-1.99241058]])
lr.fit(X2,y1)
y_predict2 = lr.predict(X2)
plt.scatter(x,y1)
# Sort the x values so each point pairs with its prediction when plotting.
plt.plot(np.sort(x),y_predict2[np.argsort(x)])
[<matplotlib.lines.Line2D at 0x2b1ab85ec18>]
# Sort x in ascending order and return the indices of the sorted elements
# within the original array.
np.argsort(x)
array([24, 66, 68, 80, 42, 40, 25, 50, 33, 4, 96, 90, 71, 27, 59, 55, 46,
12, 26, 21, 32, 13, 2, 15, 5, 85, 79, 18, 60, 9, 56, 92, 22, 53,
95, 82, 14, 86, 78, 74, 47, 34, 48, 76, 1, 8, 77, 69, 54, 29, 35,
30, 3, 81, 7, 75, 57, 61, 51, 0, 52, 93, 58, 6, 37, 91, 23, 63,
41, 64, 84, 70, 99, 88, 10, 45, 72, 38, 28, 43, 31, 87, 94, 16, 89,
17, 62, 11, 83, 73, 49, 97, 65, 36, 20, 67, 39, 44, 98, 19],
dtype=int64)
# Sort the sequence in ascending order.
np.sort(x)
array([-2.6831202 , -2.62274806, -2.62243954, -2.60175714, -2.59777836,
-2.51476141, -2.40013065, -2.07196042, -2.00218193, -1.99241058,
-1.94921809, -1.89001765, -1.85746619, -1.73266456, -1.5079657 ,
-1.46055554, -1.44983385, -1.36380038, -1.3489865 , -1.33688328,
-1.31356888, -1.26772816, -1.24388168, -1.21782647, -1.15287568,
-1.09294743, -1.090961 , -0.98296742, -0.97797343, -0.94437214,
-0.92640668, -0.86728969, -0.8511405 , -0.80955785, -0.78685291,
-0.58771057, -0.55872249, -0.52928834, -0.49505029, -0.44419329,
-0.36288847, -0.31918498, -0.23628546, -0.16081948, -0.11502597,
0.12157036, 0.14776611, 0.14901036, 0.18998966, 0.21356265,
0.29642117, 0.3305497 , 0.36996037, 0.41914586, 0.43372833,
0.44752385, 0.46121942, 0.53090615, 0.59035022, 0.64084587,
0.8199342 , 0.9114925 , 0.93280307, 0.95485448, 0.9689322 ,
0.97692067, 0.98721063, 1.09406365, 1.09974238, 1.14033716,
1.30988586, 1.34764166, 1.37511611, 1.43164976, 1.47413249,
1.49494967, 1.50709718, 1.52712763, 1.60959222, 1.61336388,
1.65001695, 1.6574256 , 1.73473852, 1.74740853, 1.77197508,
1.84385472, 1.8992068 , 2.04227617, 2.20514719, 2.2149727 ,
2.24890589, 2.30754732, 2.46727101, 2.65375127, 2.72478343,
2.78741218, 2.80490711, 2.92498289, 2.94786966, 2.97871352])