Python Deep Learning: Linear Units (note)
https://www.zybuluo.com/hanbingtao/note/448086
In [3]:
# A linear unit outputs a real value rather than a 0/1 class, so it solves regression problems rather than classification problems.
# When predicting a salary, taking more factors into account, such as industry, company, and job grade, should make the prediction much more reliable.
# Years of experience, industry, company, and grade are called features.
# A person with 5 years of experience, working in IT at Baidu at grade T6, can be represented by the feature vector x = (5, IT, Baidu, T6).
# Since the input is now a vector of four features, a single weight is no longer enough; we need four weights, one per feature.
# Key concept 1: error e = ((labeled_y - predicted_y)^2) / 2; total error E = sum(e)
# Key concept 2: optimization: training the model means finding suitable weights.
# Key concept 3: gradient descent. The gradient is a vector pointing in the direction in which the function value increases fastest,
# so the opposite direction is the direction of fastest decrease.
# By repeatedly moving the weights in the direction opposite to the gradient, we get near the minimum of the function.
# We only end up near the minimum, not exactly at it, because the step size is never exactly right;
# the last step may overshoot the minimum.
# Choosing the step size is an art: too small, and it takes many iterations to get near the minimum;
# too large, and we may overshoot the minimum by a lot and fail to converge to a good point.
# Stochastic Gradient Descent (SGD)
from functools import reduce
from DL import Perceptron  # reuse the Perceptron class already defined in DL.py

# The activation function of a linear unit is the identity
f = lambda x: x

class LinearUnit(Perceptron):
    def __init__(self, input_num):
        Perceptron.__init__(self, input_num, f)
Output:
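The `DL` module that provides `Perceptron` is not shown in this note. Below is a minimal stand-in, assuming only the interface used here (`train(input_vecs, labels, iteration, rate)`, `predict(input_vec)`, and printable weights/bias) and per-sample SGD updates; the real DL.py also logs intermediate weights, which this sketch omits.

```python
class Perceptron:
    def __init__(self, input_num, activator):
        # One weight per input feature plus a bias, all initialized to 0
        self.activator = activator
        self.weights = [0.0] * input_num
        self.bias = 0.0

    def __str__(self):
        return 'weights :%s\nbias :%f' % (self.weights, self.bias)

    def predict(self, input_vec):
        # activator(w . x + b)
        return self.activator(
            sum(x * w for x, w in zip(input_vec, self.weights)) + self.bias)

    def train(self, input_vecs, labels, iteration, rate):
        # Per-sample (stochastic) gradient descent:
        # w <- w + rate * (y - y_hat) * x ;  b <- b + rate * (y - y_hat)
        for _ in range(iteration):
            for input_vec, label in zip(input_vecs, labels):
                delta = label - self.predict(input_vec)
                self.weights = [w + rate * delta * x
                                for x, w in zip(input_vec, self.weights)]
                self.bias += rate * delta
```

With the identity activator and the dataset below, this update rule reproduces the logged weights (first update: 0.01 * 5500 * 5 = 275.0, matching the first "intermediate weights" line).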
In [4]:
def get_training_dataset():
    '''
    Fabricate income data for 5 people
    '''
    # Build the training data
    # Input vectors: each entry is years of work experience
    input_vecs = [[5], [3], [8], [1.4], [10.1]]
    # Expected outputs (monthly salary), matched one-to-one with the inputs
    labels = [5500, 2300, 7600, 1800, 11400]
    return input_vecs, labels

def train_linear_unit():
    '''
    Train a linear unit on the data
    '''
    # Create the linear unit; the input has 1 feature (years of experience)
    lu = LinearUnit(1)
    # Train for 10 iterations with a learning rate of 0.01
    input_vecs, labels = get_training_dataset()
    lu.train(input_vecs, labels, 10, 0.01)
    # Return the trained linear unit
    return lu
if __name__ == '__main__':
    '''Train the linear unit'''
    linear_unit = train_linear_unit()
    # Print the learned weights
    print(linear_unit)
    # Test
    print('Work 3.4 years, monthly salary = %.2f' % linear_unit.predict([3.4]))
    print('Work 15 years, monthly salary = %.2f' % linear_unit.predict([15]))
    print('Work 1.5 years, monthly salary = %.2f' % linear_unit.predict([1.5]))
    print('Work 6.3 years, monthly salary = %.2f' % linear_unit.predict([6.3]))
Output:
intermediate weights:
[275.0]
intermediate weights:
[317.6]
intermediate weights:
[716.8]
intermediate weights:
[726.28332]
intermediate weights:
[1124.0884514680001]
1 iteration done!
==========
intermediate weights:
[1109.8033162670001]
intermediate weights:
[1074.0489152137761]
intermediate weights:
[982.618786600528]
intermediate weights:
[986.6126671048553]
intermediate weights:
[1117.2362469327707]
2 iteration done!
==========
intermediate weights:
[1105.1850692309142]
intermediate weights:
[1070.145450485145]
intermediate weights:
[981.9921853012403]
intermediate weights:
[986.1288757305607]
intermediate weights:
[1118.1773386013936]
3 iteration done!
==========
intermediate weights:
[1106.344907106921]
intermediate weights:
[1071.4720019497904]
intermediate weights:
[983.188228792761]
intermediate weights:
[987.4274400273441]
intermediate weights:
[1119.0525762457441]
4 iteration done!
==========
intermediate weights:
[1107.4496383368378]
intermediate weights:
[1072.7449122054936]
intermediate weights:
[984.35560460807]
intermediate weights:
[988.6962173835135]
intermediate weights:
[1119.9163650519292]
5 iteration done!
==========
intermediate weights:
[1108.5397280189734]
intermediate weights:
[1074.0008848574794]
intermediate weights:
[985.507307326902]
intermediate weights:
[989.94795090016]
intermediate weights:
[1120.768490220792]
6 iteration done!
==========
intermediate weights:
[1109.61509966593]
intermediate weights:
[1075.2399002395273]
intermediate weights:
[986.6434615203494]
intermediate weights:
[991.1827854932801]
intermediate weights:
[1121.6091117600865]
7 iteration done!
==========
intermediate weights:
[1110.6759538625565]
intermediate weights:
[1076.4621889878067]
intermediate weights:
[987.7642776932844]
intermediate weights:
[992.4009498853544]
intermediate weights:
[1122.43838495064]
8 iteration done!
==========
intermediate weights:
[1111.722486580645]
intermediate weights:
[1077.6679768992058]
intermediate weights:
[988.8699629040694]
intermediate weights:
[993.6026691192034]
intermediate weights:
[1123.2564629947137]
9 iteration done!
==========
intermediate weights:
[1112.7548911595277]
intermediate weights:
[1078.8574867342768]
intermediate weights:
[989.9607214199639]
intermediate weights:
[994.7881652036668]
intermediate weights:
[1124.0634970262222]
10 iteration done!
==========
weights :[1124.0634970262222]
bias :85.485289
Work 3.4 years, monthly salary = 3907.30
Work 15 years, monthly salary = 16946.44
Work 1.5 years, monthly salary = 1771.58
Work 6.3 years, monthly salary = 7167.09
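As a sanity check (not part of the original note), the SGD result can be compared with the closed-form least-squares fit of the same five points. After only 10 epochs, SGD (weight about 1124, bias about 85) has not yet settled on the least-squares line:

```python
# Closed-form simple linear regression on the same five data points
xs = [5, 3, 8, 1.4, 10.1]
ys = [5500, 2300, 7600, 1800, 11400]

n = len(xs)
mean_x = sum(xs) / n
mean_y = sum(ys) / n

# slope = Sxy / Sxx, intercept = mean_y - slope * mean_x
sxx = sum((x - mean_x) ** 2 for x in xs)
sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
slope = sxy / sxx                      # about 1095.4
intercept = mean_y - slope * mean_x    # about -304.8
print('least squares: salary = %.2f * years + %.2f' % (slope, intercept))
```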
In [11]:
# %pylab inline imports matplotlib.pyplot as plt (and numpy) into the namespace
%pylab inline
# Scatter the four test predictions (years, predicted salary)
plt.plot([3.4, 15, 1.5, 6.3], [3907, 16946, 1771, 7167], '*')
# Plot the fitted line: salary = 1124.06 * years + 85.48
plt.plot([0, 20], [85.48, 1124.06 * 20 + 85.48])
plt.show()
Output:
Populating the interactive namespace from numpy and matplotlib