The Least Squares Principle in Machine Learning Linear Regression

1. The Least Squares Principle

1.1 Implementing Least Squares with Classical Mathematics

%matplotlib inline

import numpy as np
import matplotlib.pyplot as plt
import sklearn
from sklearn import datasets
import pandas as pd

# Load the diabetes dataset and assign it to d.
d = datasets.load_diabetes()
# Use all rows of the third feature column (index 2) as the training input.
X = d.data[:,2]
# Assign the target values to Y.
Y = d.target

# Draw the scatter plot.
plt.scatter(X,Y)
plt.xlabel('X')
plt.ylabel('Y')
plt.show()

[Figure: scatter plot of the diabetes feature X against the target Y]

The following part implements least-squares fitting of the linear model $y = ax + b$, i.e. how to solve for the parameters a and b.
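Minimizing the squared error $L=\sum_{i=1}^N (y_i - a x_i - b)^2$ by setting $\partial L/\partial a = \partial L/\partial b = 0$ gives the normal equations, which are exactly the $2\times 2$ system $A\,[a, b]^{\mathsf T} = B$ assembled in the code below:

$$\begin{pmatrix} \sum_i x_i^2 & \sum_i x_i \\ \sum_i x_i & N \end{pmatrix} \begin{pmatrix} a \\ b \end{pmatrix} = \begin{pmatrix} \sum_i x_i y_i \\ \sum_i y_i \end{pmatrix}$$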

# N is the number of samples (rows) in X.
N = X.shape[0]

# Compute the sums that appear in the normal-equation matrices.
S_x2 = np.sum(X*X)
S_x = np.sum(X)
S_xy = np.sum(X*Y)
S_y = np.sum(Y)

# Build the matrices A and B in preparation for solving the system.
A = np.array([[S_x2,S_x],
              [S_x,N]])
B = np.array([S_xy,S_y])

# Solve the linear system A·[a, b]^T = B with np.linalg.solve from NumPy's linear-algebra module.
jie = np.linalg.solve(A,B)

# Print the solved parameters.
print('a=%f,b=%f' % (jie[0],jie[1]))
# Compute the x range and the corresponding y values of the fitted line.
x_min = np.min(X)
x_max = np.max(X)
y_min = jie[0] * x_min + jie[1]
y_max = jie[0] * x_max + jie[1]

# Plot the data together with the least-squares fitted line.
plt.scatter(X,Y,label = 'original data')
plt.plot([x_min,x_max],[y_min,y_max],'r',label = 'model')
plt.legend()
plt.show()
a=949.435260,b=152.133484

[Figure: original data with the fitted least-squares line]
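As a cross-check (a small sketch, assuming X and Y from the cells above), NumPy's np.polyfit solves the same degree-1 least-squares problem directly:

# np.polyfit returns the polynomial coefficients [a, b] of the least-squares fit.
coeffs = np.polyfit(X, Y, 1)
print('a=%f,b=%f' % (coeffs[0], coeffs[1]))   # should match the values printed above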

Using gradient descent (GD):

For a supervised learning model, we construct a loss function on top of the original model and then optimize it with an optimization algorithm in order to find the optimal parameters.

For a loss function $L=\sum_{i=1}^N (y_i - a x_i - b)^2$, the update rule is $\theta^{1} = \theta^{0} - \alpha \nabla L(\theta)$, where $\theta$ stands for the model parameters, e.g. $a$ and $b$.
In this update rule, the descent direction is the gradient, i.e. the first derivative of the loss function.
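Concretely, a single sample $(x_j, y_j)$ contributes the gradients $\partial L_j/\partial a = -2(y_j - a x_j - b)\,x_j$ and $\partial L_j/\partial b = -2(y_j - a x_j - b)$, so the updates implemented in the loop below are

$$a \leftarrow a + 2\alpha\,(y_j - a x_j - b)\,x_j, \qquad b \leftarrow b + 2\alpha\,(y_j - a x_j - b).$$

Note that updating one sample at a time, as the inner loop does, is stochastic gradient descent rather than full-batch descent.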

# Number of update epochs
n_cishu = 400
# Initialize the parameter values
a,b = 1,1
# Learning rate, i.e. the step size
lv = 0.01

for i in range(n_cishu):
#Sweep over the N samples, applying the per-sample update rule.
    for j in range(N):
        a = a + lv*2*(Y[j]-a*X[j]-b)*X[j]
        b = b + lv*2*(Y[j]-a*X[j]-b)
        
    L=0
#Recompute the loss using the loss formula.
    for j in range(N):
        L =  L + (Y[j]-a*X[j]-b)**2
    print('cishu %4d: loss = %f, a = %f, b= %f' % (i,L,a,b))
    
x_min = np.min(X)
x_max = np.max(X)
y_min = a*x_min + b
y_max = a*x_max + b

plt.scatter(X,Y,label = 'original data')
plt.plot([x_min,x_max],[y_min,y_max],'r',label = 'model')
plt.legend()
plt.show()
cishu    0: loss = 2590736.867664, a = 18.689017, b= 148.815329
cishu    1: loss = 2557255.189453, a = 36.831348, b= 148.828655
cishu    2: loss = 2525130.800853, a = 54.614824, b= 148.822534
...
cishu  398: loss = 1725555.027283, a = 935.306533, b= 148.519285
cishu  399: loss = 1725554.854212, a = 935.312908, b= 148.519283

[Figure: original data with the gradient-descent fitted line (output_6_1.png)]
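The double loop above updates the parameters one sample at a time. A vectorized full-batch variant (a sketch assuming X and Y from the cells above; the learning rate and iteration count are illustrative and may need tuning) looks like this:

# Full-batch gradient descent on L = sum((Y - a*X - b)^2), using mean gradients.
a, b = 1.0, 1.0
lr = 0.5                           # illustrative step size; tune for your data scaling
for i in range(5000):
    r = Y - a * X - b              # residuals
    a += lr * 2 * np.mean(r * X)   # -(1/N) * dL/da
    b += lr * 2 * np.mean(r)       # -(1/N) * dL/db
print('a = %f, b = %f' % (a, b))   # should approach the closed-form solution above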

1.2 Visualizing the Regression Process Dynamically

FuncAnimation(fig,func,frames,init_func,interval,blit)

  • fig: the figure on which the animation is drawn.
  • func: the user-defined animation function.
  • frames: the animation length, i.e. the number of frames in one loop; at run time the frame index is passed to func(n) as the argument n.
  • init_func: the user-defined starting frame, i.e. the init function passed in to initialize the animation.
  • interval: the update interval in milliseconds.
  • blit: whether to redraw all artists or only those that changed; it should be set to True.
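Since the list above describes FuncAnimation while the cell below actually uses ArtistAnimation, here is a minimal, self-contained FuncAnimation sketch (the animated sine curve is hypothetical, purely for illustration):

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

fig, ax = plt.subplots()
line, = ax.plot([], [], 'r')
ax.set_xlim(0, 2 * np.pi)
ax.set_ylim(-1.1, 1.1)

# init_func: set up the empty first frame.
def init():
    line.set_data([], [])
    return line,

# func: frame n updates only the line's data (blit=True redraws just this artist).
def update(n):
    x = np.linspace(0, 2 * np.pi, 200)
    line.set_data(x, np.sin(x + 0.1 * n))
    return line,

ani = FuncAnimation(fig, update, frames=100, init_func=init, interval=50, blit=True)
plt.show()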
%matplotlib nbagg

import matplotlib.pyplot as plt
import matplotlib.animation as animation

n_cishu = 3000
a,b = 1,1
xuexi = 0.001

fig = plt.figure()
imgs = []

for i in range(n_cishu):
    for j in range(N):
        a = a + xuexi*2*(Y[j] - a*X[j]-b)*X[j]
        b = b + xuexi*2*(Y[j] - a*X[j]-b)
        
    L = 0
    for j in range(N):
        L = L + (Y[j] - a*X[j]-b)**2
        
    if i % 50 == 0:
        x_min = np.min(X)
        x_max = np.max(X)
        y_min = a*x_min+b
        y_max = a*x_max+b
        
        scat = plt.scatter(X,Y,label='original data')
        line_artists = plt.plot([x_min,x_max],[y_min,y_max],'r',label='model')
        # Each frame is a list of artists to show: the data plus the current fitted line.
        imgs.append([scat] + line_artists)
        
ani = animation.ArtistAnimation(fig,imgs)
plt.show()
<IPython.core.display.Javascript object>

1.3 Linear Regression with sklearn

The linear regression model:

lr = sklearn.linear_model.LinearRegression(fit_intercept=True, normalize=False, copy_X=True, n_jobs=1)

  • fit_intercept: defaults to True; whether to compute the model's intercept. If False, the data is expected to be centered already.
  • normalize: defaults to False; whether to normalize the regressors, or alternatively use sklearn.preprocessing.StandardScaler().
  • copy_X: defaults to True; otherwise X may be overwritten.
  • n_jobs: defaults to 1; the number of CPUs to use.

Fitted attributes and methods:

  • coef_: the fitted coefficients of the model inputs.
  • intercept_: the intercept.
  • predict(x): predict on new data.
  • score: evaluate the model.
#Solve the linear problem, i.e. the coefficients, with sklearn.
from sklearn import linear_model
from sklearn import datasets
import numpy as np
import matplotlib.pyplot as plt

d = datasets.load_diabetes()

X = d.data[:,np.newaxis,2]
Y = d.target
print('X:',X.shape)
print('Y:',Y.shape)

#See above for details on using the linear model.
regr = linear_model.LinearRegression()
regr.fit(X,Y)

a,b = regr.coef_,regr.intercept_
print('a= %f,b= %f' % (a[0],b))

x_min = np.min(X)
x_max = np.max(X)
y_min = a*x_min+b
y_max = a*x_max+b

plt.scatter(X,Y)
plt.plot([x_min,x_max],[y_min,y_max],'r')
plt.show()
X: (442, 1)
Y: (442,)
a= 949.435260,b= 152.133484



<IPython.core.display.Javascript object>

2. How Linear Regression Fits Polynomial Functions

Fitting a polynomial function with sklearn

PolynomialFeatures constructs new features (products of existing features). It takes three parameters:

  • degree: controls the degree of the polynomial;
  • interaction_only: defaults to False; if set to True, no feature is combined with itself, so the generated features contain no $a^2$ or $b^2$ terms;
  • include_bias: defaults to True; if True, the result includes the zeroth-power term, i.e. a column of all ones.
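A quick way to see the effect of these parameters is a small sketch on a single two-feature sample [a, b] = [2, 3]:

from sklearn.preprocessing import PolynomialFeatures
import numpy as np

X = np.array([[2., 3.]])
# Default: [1, a, b, a^2, ab, b^2]
print(PolynomialFeatures(degree=2).fit_transform(X))                          # [[1. 2. 3. 4. 6. 9.]]
# interaction_only=True drops the a^2 and b^2 terms: [1, a, b, ab]
print(PolynomialFeatures(degree=2, interaction_only=True).fit_transform(X))  # [[1. 2. 3. 6.]]
# include_bias=False drops the all-ones column.
print(PolynomialFeatures(degree=2, include_bias=False).fit_transform(X))     # [[2. 3. 4. 6. 9.]]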

The Pipeline class (an assembly line) chains a series of estimators into a pipeline and then lets the data "run" along it.
class sklearn.pipeline.Pipeline(steps,memory=None,verbose=False):

  • steps: the stages on the pipeline, each given a name; a list of (name, step) tuples, where every step implements fit/transform and the last object is an estimator.
  • memory: None, str, etc.; used to cache the pipeline's fitted transformers. By default no caching is performed; a string can specify a cache path.
  • verbose: defaults to False; whether to print the time consumed by each pipeline step (hidden by default).

Attributes:

  • named_steps: a read-only attribute for accessing any step by its user-given name; keys are the step names and values are the step objects.
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline

t = np.array([2,4,6,8])

pa = -20
pb  = 90
pc = 800

y = pa*t**2 + pb*t +pc

#Implement polynomial regression quickly via a Pipeline.
model = Pipeline([('poly',PolynomialFeatures(degree=2)),
                  ('linear',LinearRegression(fit_intercept=False))])
model = model.fit(t[:,np.newaxis],y)
model.named_steps['linear'].coef_
array([800.,  90., -20.])

Pipeline links the algorithms together.

%matplotlib inline

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeatures
from sklearn.pipeline import Pipeline

#Draw 100 values uniformly at random from [-3, 3].
x = np.random.uniform(-3,3,size = 100)
#Reshape into a single column.
X1 = x.reshape(-1,1)
x
array([ 0.64084587, -0.11502597, -1.24388168,  0.36996037, -1.99241058,
       -1.15287568,  0.95485448,  0.43372833,  0.12157036, -0.94437214,
        1.47413249,  2.04227617, -1.36380038, -1.26772816, -0.55872249,
       -1.21782647,  1.74740853,  1.84385472, -0.98296742,  2.97871352,
        2.72478343, -1.33688328, -0.8511405 ,  0.98721063, -2.6831202 ,
       -2.40013065, -1.3489865 , -1.73266456,  1.60959222,  0.21356265,
        0.3305497 ,  1.65001695, -1.31356888, -2.00218193, -0.31918498,
        0.29642117,  2.65375127,  0.9689322 ,  1.52712763,  2.80490711,
       -2.51476141,  1.09974238, -2.59777836,  1.61336388,  2.92498289,
        1.49494967, -1.44983385, -0.36288847, -0.23628546,  2.24890589,
       -2.07196042,  0.59035022,  0.8199342 , -0.80955785,  0.18998966,
       -1.46055554, -0.92640668,  0.46121942,  0.93280307, -1.5079657 ,
       -0.97797343,  0.53090615,  1.8992068 ,  1.09406365,  1.14033716,
        2.46727101, -2.62274806,  2.78741218, -2.62243954,  0.14901036,
        1.34764166, -1.85746619,  1.50709718,  2.2149727 , -0.44419329,
        0.44752385, -0.16081948,  0.14776611, -0.49505029, -1.090961  ,
       -2.60175714,  0.41914586, -0.58771057,  2.20514719,  1.30988586,
       -1.09294743, -0.52928834,  1.6574256 ,  1.43164976,  1.77197508,
       -1.89001765,  0.97692067, -0.86728969,  0.9114925 ,  1.73473852,
       -0.78685291, -1.94921809,  2.30754732,  2.94786966,  1.37511611])
X1[0:5,:]
array([[ 0.64084587],
       [-0.11502597],
       [-1.24388168],
       [ 0.36996037],
       [-1.99241058]])
y1 = 0.5*x**2 + x +2 + np.random.normal(0,2,size=100)
plt.scatter(x,y1)
<matplotlib.collections.PathCollection at 0x2b1aa13c1d0>

[Figure: scatter plot of the noisy quadratic data]

y1
array([ 0.96818306,  2.23883725,  1.92468482, -2.71637306,  4.45307092,
        0.01816149,  1.83609337,  2.84743306,  1.6606926 ,  3.32939878,
        5.13270614,  6.94994308,  5.26852484,  0.776832  , -3.68122139,
       -0.5092344 ,  1.73890044,  8.53156829,  4.27045319, 10.39202826,
       11.0183172 , -1.88915938, -1.60786658,  1.98119282,  3.45607561,
        5.52254344,  2.62589771, -0.964355  ,  6.13707036,  1.13207466,
        1.00973905,  4.38597909,  0.8282406 ,  0.31836746,  2.52730438,
        6.13836595,  7.78835002,  6.40451125,  6.97002402,  7.66174213,
        1.70348611,  5.42996596,  1.67880113,  1.21632926,  9.18677169,
        5.29843559, -0.67039048,  3.26823983,  1.46620622,  3.56803585,
       -2.83134549,  5.75070392,  1.0478809 ,  0.30132542,  1.9288644 ,
        4.98144168,  3.62208072, -0.06874818,  4.21452823,  4.88247452,
        3.85579813,  1.16749376,  7.62196597,  4.08134595,  0.73896439,
        7.45182763,  4.77644065,  8.85020327,  4.1570157 ,  1.23410209,
        9.68626177, -1.28932949,  5.22198367,  9.94091774,  1.9535322 ,
        1.64324317,  0.28503905,  2.34850118,  2.47608781,  1.7209606 ,
        0.26932514,  2.81809854,  1.31506573,  5.54638903,  0.90487236,
        3.36809024,  0.43599388,  1.11588937,  2.35450458, 10.34477969,
        1.64070228,  4.89975885,  0.0448951 ,  2.01677536,  6.01236522,
        1.80500177,  2.37439515,  3.61958116,  8.90862482,  0.606599  ])
#Create the linear regression model.
lr = LinearRegression()
#Train the model.
lr.fit(X1,y1)
#Predict with the model.
y_predict = lr.predict(X1)

plt.scatter(x,y1)
plt.plot(np.sort(x),y_predict[np.argsort(x)])
[<matplotlib.lines.Line2D at 0x2b1ab684860>]

[Figure: the noisy quadratic data with an underfitting straight-line fit]

poly = PolynomialFeatures(degree=2)
#Add quadratic features so a univariate quadratic can be fitted.
poly.fit(X1)
X2 = poly.transform(X1)

X2[0:5,:]
array([[ 1.        ,  0.64084587,  0.41068343],
       [ 1.        , -0.11502597,  0.01323097],
       [ 1.        , -1.24388168,  1.54724163],
       [ 1.        ,  0.36996037,  0.13687068],
       [ 1.        , -1.99241058,  3.96969993]])
#As shown above, the columns are [1, a, a^2].
0.64084587**2
0.41068342909605693
X1[0:5,:]
array([[ 0.64084587],
       [-0.11502597],
       [-1.24388168],
       [ 0.36996037],
       [-1.99241058]])
lr.fit(X2,y1)
y_predict2 = lr.predict(X2)

plt.scatter(x,y1)
#Sort x and reorder the predictions to match, so the curve is drawn in order.
plt.plot(np.sort(x),y_predict2[np.argsort(x)])
[<matplotlib.lines.Line2D at 0x2b1ab85ec18>]

[Figure: the noisy quadratic data with the degree-2 polynomial fit]

#argsort returns the indices that would sort x in ascending order.
np.argsort(x)
array([24, 66, 68, 80, 42, 40, 25, 50, 33,  4, 96, 90, 71, 27, 59, 55, 46,
       12, 26, 21, 32, 13,  2, 15,  5, 85, 79, 18, 60,  9, 56, 92, 22, 53,
       95, 82, 14, 86, 78, 74, 47, 34, 48, 76,  1,  8, 77, 69, 54, 29, 35,
       30,  3, 81,  7, 75, 57, 61, 51,  0, 52, 93, 58,  6, 37, 91, 23, 63,
       41, 64, 84, 70, 99, 88, 10, 45, 72, 38, 28, 43, 31, 87, 94, 16, 89,
       17, 62, 11, 83, 73, 49, 97, 65, 36, 20, 67, 39, 44, 98, 19],
      dtype=int64)
#Sort the sequence in ascending order.
np.sort(x)
array([-2.6831202 , -2.62274806, -2.62243954, -2.60175714, -2.59777836,
       -2.51476141, -2.40013065, -2.07196042, -2.00218193, -1.99241058,
       -1.94921809, -1.89001765, -1.85746619, -1.73266456, -1.5079657 ,
       -1.46055554, -1.44983385, -1.36380038, -1.3489865 , -1.33688328,
       -1.31356888, -1.26772816, -1.24388168, -1.21782647, -1.15287568,
       -1.09294743, -1.090961  , -0.98296742, -0.97797343, -0.94437214,
       -0.92640668, -0.86728969, -0.8511405 , -0.80955785, -0.78685291,
       -0.58771057, -0.55872249, -0.52928834, -0.49505029, -0.44419329,
       -0.36288847, -0.31918498, -0.23628546, -0.16081948, -0.11502597,
        0.12157036,  0.14776611,  0.14901036,  0.18998966,  0.21356265,
        0.29642117,  0.3305497 ,  0.36996037,  0.41914586,  0.43372833,
        0.44752385,  0.46121942,  0.53090615,  0.59035022,  0.64084587,
        0.8199342 ,  0.9114925 ,  0.93280307,  0.95485448,  0.9689322 ,
        0.97692067,  0.98721063,  1.09406365,  1.09974238,  1.14033716,
        1.30988586,  1.34764166,  1.37511611,  1.43164976,  1.47413249,
        1.49494967,  1.50709718,  1.52712763,  1.60959222,  1.61336388,
        1.65001695,  1.6574256 ,  1.73473852,  1.74740853,  1.77197508,
        1.84385472,  1.8992068 ,  2.04227617,  2.20514719,  2.2149727 ,
        2.24890589,  2.30754732,  2.46727101,  2.65375127,  2.72478343,
        2.78741218,  2.80490711,  2.92498289,  2.94786966,  2.97871352])
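As a final sanity check (a sketch, assuming lr is still the model fitted on X2 above), the learned parameters should sit near the generating model y1 = 0.5x² + x + 2, up to the added noise:

# Generating model: y = 0.5*x^2 + x + 2 (plus Gaussian noise).
print(lr.intercept_)   # roughly 2
print(lr.coef_)        # roughly [0, 1, 0.5]; the all-ones column's weight is absorbed by intercept_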