本文仅作为学习参考文章第二节关于线性单元和梯度下降的一个补充,主要是代码部分python画图从python2改到python3,无其他改动
关于梯度下降和反向传播之前的博文已经写过,此处不再详述
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# 其实代码同参考文章作者所说,只不过是去掉了激活函数,也就是分类的过程,其他并无任何区别,多了画图部分也只是方便理解。
from perceptron import Perceptron
# 定义激活函数f
f = lambda x: x
class LinearUnit(Perceptron):
def __init__(self, input_num):
'''初始化线性单元,设置输入参数的个数'''
Perceptron.__init__(self, input_num, f)
def get_training_dataset():
'''
捏造5个人的收入数据
'''
# 构建训练数据
# 输入向量列表,每一项是工作年限
input_vecs = [[5], [3], [8], [1.4], [10.1]]
# 期望的输出列表,月薪,注意要与输入一一对应
labels = [5500, 2300, 7600, 1800, 11400]
return input_vecs, labels
def train_linear_unit():
'''
使用数据训练线性单元
'''
# 创建感知器,输入参数的特征数为1(工作年限)
lu = LinearUnit(1)
# 训练,迭代10轮, 学习速率为0.01
input_vecs, labels = get_training_dataset()
lu.train(input_vecs, labels, 10, 0.01)
# 返回训练好的线性单元
return lu
def plot(linear_unit):
import matplotlib.pyplot as plt
input_vecs, labels = get_training_dataset()
fig = plt.figure()
# 分区域,111表示第一行第一列的第一个区域
ax = fig.add_subplot(111)
# scatter 散点图
ax.scatter(list(map(lambda x: x[0], input_vecs)), labels)
weights = linear_unit.weights
bias = linear_unit.bias
x = range(0,12,1)
y = list(map(lambda x:weights[0] * x + bias, x))
ax.plot(x, y)
plt.show()
if __name__ == '__main__':
'''训练线性单元'''
linear_unit = train_linear_unit()
# 打印训练获得的权重
print(linear_unit)
# 测试
print('Work 3.4 years, monthly salary = %.2f' % linear_unit.predict([3.4]))
print('Work 15 years, monthly salary = %.2f' % linear_unit.predict([15]))
print('Work 1.5 years, monthly salary = %.2f' % linear_unit.predict([1.5]))
print('Work 6.3 years, monthly salary = %.2f' % linear_unit.predict([6.3]))
plot(linear_unit)
代码运行结果和图示分别为:
输入,权重 [5] [0.0]
输入,输出 权重 偏移项 delta [5] 0.0 [275.0] 55.0 5500.0
输入,权重 [3] [275.0]
输入,输出 权重 偏移项 delta [3] 880.0 [317.6] 69.2 1420.0
输入,权重 [8] [317.6]
输入,输出 权重 偏移项 delta [8] 2610.0 [716.8] 119.1 4990.0
输入,权重 [1.4] [716.8]
输入,输出 权重 偏移项 delta [1.4] 1122.62 [726.28332] 125.87379999999999 677.3800000000001
输入,权重 [10.1] [726.28332]
输入,输出 权重 偏移项 delta [10.1] 7461.335332 [1124.0884514680001] 165.26044668 3938.6646680000003
输入,权重 [5] [1124.0884514680001]
输入,输出 权重 偏移项 delta [5] 5785.702704020001 [1109.8033162670001] 162.40341963979998 -285.70270402000097
输入,权重 [3] [1109.8033162670001]
输入,输出 权重 偏移项 delta [3] 3491.813368440801 [1074.0489152137761] 150.48528595539196 -1191.8133684408008
输入,权重 [8] [1074.0489152137761]
输入,输出 权重 偏移项 delta [8] 8742.876607665601 [982.618786600528] 139.05651987873594 -1142.8766076656011
输入,权重 [1.4] [982.618786600528]
输入,输出 权重 偏移项 delta [1.4] 1514.7228211194752 [986.6126671048553] 141.9092916675412 285.27717888052484
输入,权重 [10.1] [986.6126671048553]
输入,输出 权重 偏移项 delta [10.1] 10106.69722942658 [1117.2362469327707] 154.8423193732754 1293.3027705734203
输入,权重 [5] [1117.2362469327707]
输入,输出 权重 偏移项 delta [5] 5741.023554037129 [1105.1850692309142] 152.4320838329041 -241.02355403712863
输入,权重 [3] [1105.1850692309142]
输入,输出 权重 偏移项 delta [3] 3467.9872915256465 [1070.145450485145] 140.75221091764766 -1167.9872915256465
输入,权重 [8] [1070.145450485145]
输入,输出 权重 偏移项 delta [8] 8701.915814798807 [981.9921853012403] 129.7330527696596 -1101.9158147988073
输入,权重 [1.4] [981.9921853012403]
输入,输出 权重 偏移项 delta [1.4] 1504.522112191396 [986.1288757305607] 132.68783164774564 295.477887808604
输入,权重 [10.1] [986.1288757305607]
输入,输出 权重 偏移项 delta [10.1] 10092.589476526407 [1118.1773386013936] 145.76193688248156 1307.4105234735925
输入,权重 [5] [1118.1773386013936]
输入,输出 权重 偏移项 delta [5] 5736.64862988945 [1106.344907106921] 143.39545058358706 -236.6486298894497
输入,权重 [3] [1106.344907106921]
输入,输出 权重 偏移项 delta [3] 3462.4301719043506 [1071.4720019497904] 131.77114886454356 -1162.4301719043506
输入,权重 [8] [1071.4720019497904]
输入,输出 权重 偏移项 delta [8] 8703.547164462867 [983.188228792761] 120.73567721991489 -1103.547164462867
输入,权重 [1.4] [983.188228792761]
输入,输出 权重 偏移项 delta [1.4] 1497.1991975297801 [987.4274400273441] 123.7636852446171 302.8008024702199
输入,权重 [10.1] [987.4274400273441]
输入,输出 权重 偏移项 delta [10.1] 10096.780829520792 [1119.0525762457441] 136.79587694940918 1303.2191704792076
输入,权重 [5] [1119.0525762457441]
输入,输出 权重 偏移项 delta [5] 5732.058758178129 [1107.4496383368378] 134.47528936762788 -232.0587581781292
输入,权重 [3] [1107.4496383368378]
输入,输出 权重 偏移项 delta [3] 3456.824204378141 [1072.7449122054936] 122.90704732384647 -1156.824204378141
输入,权重 [8] [1072.7449122054936]
输入,输出 权重 偏移项 delta [8] 8704.866344967795 [984.35560460807] 111.85838387416852 -1104.8663449677952
输入,权重 [1.4] [984.35560460807]
输入,输出 权重 偏移项 delta [1.4] 1489.9562303254665 [988.6962173835135] 114.95882157091386 310.04376967453345
输入,权重 [10.1] [988.6962173835135]
输入,输出 权重 偏移项 delta [10.1] 10100.7906171444 [1119.9163650519292] 127.95091539946986 1299.2093828556008
输入,权重 [5] [1119.9163650519292]
输入,输出 权重 偏移项 delta [5] 5727.532740659116 [1108.5397280189734] 125.6755879928787 -227.53274065911592
输入,权重 [3] [1108.5397280189734]
输入,输出 权重 偏移项 delta [3] 3451.2947720497987 [1074.0008848574794] 114.16264027238071 -1151.2947720497987
输入,权重 [8] [1074.0008848574794]
输入,输出 权重 偏移项 delta [8] 8706.169719132216 [985.507307326902] 103.10094308105855 -1106.1697191322164
输入,权重 [1.4] [985.507307326902]
输入,输出 权重 偏移项 delta [1.4] 1482.8111733387213 [989.94795090016] 106.27283134767133 317.18882666127865
输入,权重 [10.1] [989.94795090016]
输入,输出 权重 偏移项 delta [10.1] 10104.747135439286 [1120.768490220792] 119.22535999327847 1295.2528645607144
输入,权重 [5] [1120.768490220792]
输入,输出 权重 偏移项 delta [5] 5723.0678110972385 [1109.61509966593] 116.99468188230608 -223.06781109723852
输入,权重 [3] [1109.61509966593]
输入,输出 权重 偏移项 delta [3] 3445.8399808800964 [1075.2399002395273] 105.53628207350512 -1145.8399808800964
输入,权重 [8] [1075.2399002395273]
输入,输出 权重 偏移项 delta [8] 8707.455483989723 [986.6434615203494] 94.46172723360789 -1107.4554839897228
输入,权重 [1.4] [986.6434615203494]
输入,输出 权重 偏移项 delta [1.4] 1475.762573362097 [991.1827854932801] 97.70410149998692 324.2374266379029
输入,权重 [10.1] [991.1827854932801]
输入,输出 权重 偏移项 delta [10.1] 10108.650234982115 [1121.6091117600865] 110.61759915016577 1291.3497650178851
输入,权重 [5] [1121.6091117600865]
输入,输出 权重 偏移项 delta [5] 5718.663157950598 [1110.6759538625565] 108.43096757065979 -218.66315795059836
输入,权重 [3] [1110.6759538625565]
输入,输出 权重 偏移项 delta [3] 3440.4588291583295 [1076.4621889878067] 97.0263792790765 -1140.4588291583295
输入,权重 [8] [1076.4621889878067]
输入,输出 权重 偏移项 delta [8] 8708.72389118153 [987.7642776932844] 85.9391403672612 -1108.7238911815293
输入,权重 [1.4] [987.7642776932844]
输入,输出 权重 偏移项 delta [1.4] 1468.8091291378591 [992.4009498853544] 89.2510490758826 331.19087086214086
输入,权重 [10.1] [992.4009498853544]
输入,输出 权重 偏移项 delta [10.1] 10112.500642917963 [1122.43838495064] 102.12604264670297 1287.499357082037
输入,权重 [5] [1122.43838495064]
输入,输出 权重 偏移项 delta [5] 5714.317967399904 [1111.722486580645] 99.98286297270393 -214.31796739990386
输入,权重 [3] [1111.722486580645]
输入,输出 权重 偏移项 delta [3] 3435.150322714639 [1077.6679768992058] 88.63135974555753 -1135.150322714639
输入,权重 [8] [1077.6679768992058]
输入,输出 权重 偏移项 delta [8] 8709.975174939203 [988.8699629040694] 77.53160799616549 -1109.9751749392035
输入,权重 [1.4] [988.8699629040694]
输入,输出 权重 偏移项 delta [1.4] 1461.9495560618627 [993.6026691192034] 80.91211243554686 338.0504439381373
输入,权重 [10.1] [993.6026691192034]
输入,输出 权重 偏移项 delta [10.1] 10116.2990705395 [1123.2564629947137] 93.74912173015186 1283.7009294604995
输入,权重 [5] [1123.2564629947137]
输入,输出 权重 偏移项 delta [5] 5710.031436703721 [1112.7548911595277] 91.64880736311464 -210.03143670372083
输入,权重 [3] [1112.7548911595277]
输入,输出 权重 偏移项 delta [3] 3429.9134808416975 [1078.8574867342768] 80.34967255469766 -1129.9134808416975
输入,权重 [8] [1078.8574867342768]
输入,输出 权重 偏移项 delta [8] 8711.209566428912 [989.9607214199639] 69.23757689040855 -1111.2095664289118
输入,权重 [1.4] [989.9607214199639]
输入,输出 权重 偏移项 delta [1.4] 1455.1825868783578 [994.7881652036668] 72.68575102162497 344.8174131216422
输入,权重 [10.1] [994.7881652036668]
输入,输出 权重 偏移项 delta [10.1] 10120.04621957866 [1124.0634970262222] 85.48528882583837 1279.9537804213396
weights :[1124.0634970262222]
bias :85.485289
输入,权重 [3.4] [1124.0634970262222]
Work 3.4 years, monthly salary = 3907.30
输入,权重 [15] [1124.0634970262222]
Work 15 years, monthly salary = 16946.44
输入,权重 [1.5] [1124.0634970262222]
Work 1.5 years, monthly salary = 1771.58
输入,权重 [6.3] [1124.0634970262222]
Work 6.3 years, monthly salary = 7167.09
参考链接
1 参考文章
2 个人对梯度下降和反向传播的理解