Python Deep Learning: The Linear Unit (Jupyter notebook)

Reference: https://www.zybuluo.com/hanbingtao/note/448086

In [3]:

# A linear unit returns a real value instead of a 0/1 classification, so it is
# used for regression problems rather than classification.
# When predicting salary, taking more factors into account -- industry, company,
# job level, and so on -- makes the prediction much more reliable. Years of
# experience, industry, company, and job level are called features.
# A person with 5 years of experience, working in IT at Baidu at level T6, can be
# represented by the feature vector x = (5, IT, Baidu, T6).
# Since the input is now a vector with four features, a single weight is no longer
# enough; we need four weights, one per feature.
# Key concept 1: error. e = ((labeled_y - predicted_y)^2) / 2;  E = sum(e)
# Key concept 2: optimization -- training the model means finding suitable weights.
# Key concept 3: gradient descent. The gradient is a vector pointing in the
#   direction of fastest increase of a function, so the opposite direction is the
#   direction of fastest decrease. By repeatedly moving the weights against the
#   gradient we reach the neighborhood of the function's minimum -- only the
#   neighborhood, not the exact point, because the step size is never exactly
#   right and the final step may overshoot the minimum. Choosing the step size is
#   an art: too small and many iterations are needed to get near the minimum; too
#   large and we may overshoot by far and fail to converge to a good point.
# Stochastic Gradient Descent (SGD)
from functools import reduce
from DL import Perceptron  # reuse the Perceptron class already defined in DL

f = lambda x: x  # identity activation: the linear unit outputs a real value

class LinearUnit(Perceptron):
    def __init__(self, input_num):
        Perceptron.__init__(self, input_num, f)

Output:

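The `DL` module with the `Perceptron` base class is not shown here, so the following is an illustrative sketch of the per-sample SGD update that training performs, assuming the standard delta rule for a single-feature linear unit. Its first step reproduces the first `intermediate weights` value in the training log below.

```python
# Minimal sketch of one SGD step for a linear unit y_hat = w*x + b.
# This is a hypothetical stand-in for what DL.Perceptron.train does per sample,
# assuming it uses the standard delta rule with learning rate `rate`.
def sgd_step(w, b, x, y, rate):
    """Update weight w and bias b on one training sample (x, y)."""
    y_hat = w * x + b      # identity activation: f(x) = x
    delta = y - y_hat      # prediction error on this sample
    w += rate * delta * x  # gradient of e = (y - y_hat)^2 / 2 w.r.t. w is -(y - y_hat)*x
    b += rate * delta      # gradient w.r.t. b is -(y - y_hat)
    return w, b

# First sample of the dataset below: 5 years of experience -> 5500 yuan.
w, b = sgd_step(0.0, 0.0, 5, 5500, 0.01)
print(w, b)  # -> 275.0 55.0, matching the first "intermediate weights" line
```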
In [4]:

def get_training_dataset():
    '''
    Make up income data for 5 people
    '''
    # Build the training data.
    # List of input vectors; each entry is years of experience.
    input_vecs = [[5], [3], [8], [1.4], [10.1]]
    # Expected outputs (monthly salary), aligned one-to-one with the inputs.
    labels = [5500, 2300, 7600, 1800, 11400]
    return input_vecs, labels

def train_linear_unit():
    '''
    Train a linear unit on the data
    '''
    # Create the unit; the input has 1 feature (years of experience).
    lu = LinearUnit(1)
    # Train for 10 iterations with learning rate 0.01.
    input_vecs, labels = get_training_dataset()
    lu.train(input_vecs, labels, 10, 0.01)
    # Return the trained linear unit.
    return lu

if __name__ == '__main__':
    '''Train the linear unit'''
    linear_unit = train_linear_unit()
    # Print the learned weights.
    print(linear_unit)
    # Test.
    print('Work 3.4 years, monthly salary = %.2f' % linear_unit.predict([3.4]))
    print('Work 15 years, monthly salary = %.2f' % linear_unit.predict([15]))
    print('Work 1.5 years, monthly salary = %.2f' % linear_unit.predict([1.5]))
    print('Work 6.3 years, monthly salary = %.2f' % linear_unit.predict([6.3]))

Output:

intermediate weights:
[275.0]
intermediate weights:
[317.6]
intermediate weights:
[716.8]
intermediate weights:
[726.28332]
intermediate weights:
[1124.0884514680001]
1 iteration done!
==========
intermediate weights:
[1109.8033162670001]
intermediate weights:
[1074.0489152137761]
intermediate weights:
[982.618786600528]
intermediate weights:
[986.6126671048553]
intermediate weights:
[1117.2362469327707]
2 iteration done!
==========
intermediate weights:
[1105.1850692309142]
intermediate weights:
[1070.145450485145]
intermediate weights:
[981.9921853012403]
intermediate weights:
[986.1288757305607]
intermediate weights:
[1118.1773386013936]
3 iteration done!
==========
intermediate weights:
[1106.344907106921]
intermediate weights:
[1071.4720019497904]
intermediate weights:
[983.188228792761]
intermediate weights:
[987.4274400273441]
intermediate weights:
[1119.0525762457441]
4 iteration done!
==========
intermediate weights:
[1107.4496383368378]
intermediate weights:
[1072.7449122054936]
intermediate weights:
[984.35560460807]
intermediate weights:
[988.6962173835135]
intermediate weights:
[1119.9163650519292]
5 iteration done!
==========
intermediate weights:
[1108.5397280189734]
intermediate weights:
[1074.0008848574794]
intermediate weights:
[985.507307326902]
intermediate weights:
[989.94795090016]
intermediate weights:
[1120.768490220792]
6 iteration done!
==========
intermediate weights:
[1109.61509966593]
intermediate weights:
[1075.2399002395273]
intermediate weights:
[986.6434615203494]
intermediate weights:
[991.1827854932801]
intermediate weights:
[1121.6091117600865]
7 iteration done!
==========
intermediate weights:
[1110.6759538625565]
intermediate weights:
[1076.4621889878067]
intermediate weights:
[987.7642776932844]
intermediate weights:
[992.4009498853544]
intermediate weights:
[1122.43838495064]
8 iteration done!
==========
intermediate weights:
[1111.722486580645]
intermediate weights:
[1077.6679768992058]
intermediate weights:
[988.8699629040694]
intermediate weights:
[993.6026691192034]
intermediate weights:
[1123.2564629947137]
9 iteration done!
==========
intermediate weights:
[1112.7548911595277]
intermediate weights:
[1078.8574867342768]
intermediate weights:
[989.9607214199639]
intermediate weights:
[994.7881652036668]
intermediate weights:
[1124.0634970262222]
10 iteration done!
==========
weights	:[1124.0634970262222]
bias	:85.485289

Work 3.4 years, monthly salary = 3907.30
Work 15 years, monthly salary = 16946.44
Work 1.5 years, monthly salary = 1771.58
Work 6.3 years, monthly salary = 7167.09
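
The printed predictions can be checked directly from the final parameters reported above (weight 1124.0634970262222, bias 85.485289), since the model is just the line y = w*x + b:

```python
# Sanity check: recompute the predictions from the final weight and bias
# reported in the training output above.
w, b = 1124.0634970262222, 85.485289
for years in (3.4, 15, 1.5, 6.3):
    print('Work %g years, monthly salary = %.2f' % (years, w * years + b))
# -> 3907.30, 16946.44, 1771.58, 7167.09 -- same values as predict() printed
```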

In [11]:

# %pylab inline imports numpy and matplotlib.pyplot (as plt) into the namespace
%pylab inline
plt.plot([3.4, 15, 1.5, 6.3], [3907.30, 16946.44, 1771.58, 7167.09], '*')
plt.plot([0, 20], [85.48, 1124.06 * 20 + 85.48])
plt.show()

Output:

Populating the interactive namespace from numpy and matplotlib

(Figure: the four test predictions plotted as points against the fitted line y = 1124.06x + 85.49)
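
The `%pylab` magic is deprecated in newer IPython releases; the same figure can be produced with explicit imports (a sketch, using the weight and bias printed by the training run above):

```python
# Equivalent plot with explicit imports instead of the %pylab magic.
import matplotlib.pyplot as plt

w, b = 1124.06, 85.49  # final weight and bias from the training output
plt.plot([3.4, 15, 1.5, 6.3], [3907.30, 16946.44, 1771.58, 7167.09], '*')
plt.plot([0, 20], [b, w * 20 + b])  # fitted line y = w*x + b
plt.xlabel('years of experience')
plt.ylabel('monthly salary')
plt.show()
```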
