Simple TensorFlow (4): Linear Model Analysis

First, to be clear about the goal: generate data from a model y = k * x + b, then train a model to recover the values of k and b.


Use numpy to generate 100 random points (uniform on [0, 1)): x_data = np.random.rand(100)

Use numpy to produce noisy target values: y_data = x_data * 0.8 + (0.2 + 0.2*np.random.rand(100)). The added term is uniform on [0.2, 0.4] with mean 0.3, so the best linear fit should land near k = 0.8 and b = 0.3.


Build the linear model:

b = tf.Variable(0.)

k = tf.Variable(0.)

y = k * x_data + b
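
Note that in TF 1.x these three lines only build a computation graph; nothing is computed until a session runs it. A minimal sketch to peek at the untrained predictions:

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(y)[:5])  # all zeros, since k = b = 0 initially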


Define a loss function using the quadratic cost (mean squared error): loss = tf.reduce_mean(tf.square(y_data - y)). The theory behind it is the formula

loss = (1/n) * Σ (y_data_i - y_i)^2

where n = 100 is the number of samples.


Taking the partial derivatives of the loss with respect to k and b then locates the critical point; because the loss is a convex quadratic in k and b, that critical point is the global minimum.
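
Written out for this model (a standard calculus step, in the same plain notation as the formula above):

∂loss/∂k = -(2/n) * Σ (y_data_i - (k * x_data_i + b)) * x_data_i
∂loss/∂b = -(2/n) * Σ (y_data_i - (k * x_data_i + b))

Gradient descent does not solve these equations directly; it follows the negative gradient downhill step by step.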


Define an optimizer to train with; here we choose gradient descent: optimizer = tf.train.GradientDescentOptimizer(0.2), where 0.2 is the learning rate.

If gradient descent itself is unclear, you can reason by analogy with ordinary differentiation: just as the derivative of a quadratic function shows where the function decreases toward its lowest point, the gradient tells the optimizer which direction to move k and b. A NumPy-only sketch of the same update rule is shown below.
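
As a minimal illustration (plain NumPy, not part of the TensorFlow program), each step applies k ← k - lr * ∂loss/∂k and b ← b - lr * ∂loss/∂b:

import numpy as np

x = np.random.rand(100)
t = x * 0.8 + (0.2 + 0.2 * np.random.rand(100))  # targets, same recipe as above

k, b, lr = 0.0, 0.0, 0.2
for step in range(500):
    pred = k * x + b
    # gradients of mean((t - pred)**2) with respect to k and b
    grad_k = -2 * np.mean((t - pred) * x)
    grad_b = -2 * np.mean(t - pred)
    k -= lr * grad_k
    b -= lr * grad_b
print(k, b)  # lands near 0.8 and 0.3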


Create the training op that minimizes the cost: train = optimizer.minimize(loss). Minimizing the loss can be pictured as the amplitude of a damped spring oscillator in physics, shrinking a little on every step.
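
Under the hood (a sketch of the TF 1.x API), minimize(loss) is shorthand for computing the gradients and then applying them:

# Equivalent two-step form of optimizer.minimize(loss):
grads_and_vars = optimizer.compute_gradients(loss)
train = optimizer.apply_gradients(grads_and_vars)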


Full code (for ease of inspection, I also print the values of x_data and y_data):

import tensorflow as tf
import numpy as np

# Use numpy to generate 100 random points
x_data = np.random.rand(100)
print(x_data)

# Use numpy to produce noisy target values
y_data = x_data * 0.8 + (0.2 + 0.2*np.random.rand(100))
print(y_data)
print(y_data - (x_data * 0.8))

# Build the linear model
b = tf.Variable(0.)
k = tf.Variable(0.)
y = k * x_data + b

# Define the loss: quadratic cost (mean squared error)
loss = tf.reduce_mean(tf.square(y_data - y))

# Define a gradient-descent optimizer with learning rate 0.2
optimizer = tf.train.GradientDescentOptimizer(0.2)

# Training op: minimize the cost
train = optimizer.minimize(loss)

# Initialize the variables
init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    for step in range(1001):
        sess.run(train)
        if step % 20 == 0:
            print(step, sess.run([k, b]))


The printed x_data is:

[  7.44110800e-01   5.43004562e-01   1.40464610e-01   1.81464847e-01
   6.39729239e-02   8.55839985e-01   8.57288253e-01   5.97521668e-01
   2.78126750e-01   6.33273527e-01   7.36156567e-01   1.04051817e-01
   8.88657006e-01   6.59313727e-01   5.90586641e-01   2.20599954e-01
   4.35010023e-02   5.94651847e-01   1.34688295e-03   7.61364799e-01
   2.40110638e-01   4.30085960e-01   9.29408073e-01   2.83296015e-01
   9.66464568e-01   5.64793967e-01   2.89583567e-02   4.78895281e-01
   6.55705166e-01   9.39931048e-01   9.22667199e-02   4.17453020e-01
   8.18261854e-01   7.51253933e-01   8.34488361e-01   5.47085942e-01
   2.20812245e-01   3.64683210e-01   2.72265898e-01   6.99546767e-01
   5.17783069e-01   6.86348141e-01   8.26539294e-02   7.21142415e-01
   5.25883873e-01   7.95120312e-01   5.53413821e-02   7.85643413e-01
   6.00917521e-01   1.00047819e-01   2.25730457e-01   3.51933714e-01
   3.02575025e-01   2.47871839e-02   7.68447228e-01   5.07069502e-01
   2.80167969e-01   3.69932599e-01   2.12162593e-01   9.94378802e-01
   1.15214455e-01   6.23622264e-02   7.36749874e-01   9.28610319e-01
   5.70055090e-01   1.53292088e-01   3.41733503e-02   1.61350956e-01
   1.39131378e-01   5.95169150e-01   6.73270412e-01   3.49607503e-01
   6.76886118e-01   6.85209452e-01   2.25242789e-01   2.17403176e-01
   4.76698614e-01   3.66161202e-01   8.71111576e-01   8.51866555e-01
   2.42937403e-02   7.98989890e-01   1.56342752e-01   2.19081202e-01
   2.68967638e-01   3.66510600e-01   5.63911914e-01   3.51834652e-01
   3.53287522e-01   6.38729594e-01   8.64768242e-01   6.91163778e-01
   4.01258574e-01   4.18343511e-01   1.04330897e-04   9.33082029e-02
   3.62988647e-01   1.90686375e-01   1.37897024e-01   6.80795678e-01]


The printed y_data is:

[ 0.89765157  0.78318914  0.51019381  0.37265281  0.4484892   0.89407551
  1.05607533  0.81662998  0.60433588  0.84483753  0.81586059  0.28435128
  1.05519893  0.82464034  0.79314064  0.40467223  0.27353902  0.83144114
  0.22773844  0.81602484  0.49884444  0.6495945   1.09878211  0.55309785
  1.16746735  0.82575331  0.26850707  0.73932869  0.91451223  0.95979475
  0.42978122  0.64909303  0.95775075  0.82245154  0.99671646  0.71105713
  0.37859923  0.59687127  0.45301511  0.91201781  0.70358415  0.81903768
  0.33778258  0.96624705  0.70700762  1.02552349  0.32549882  0.83025993
  0.79845534  0.36227279  0.56073625  0.61983059  0.62606028  0.24095358
  0.96142484  0.65170733  0.53331155  0.52622143  0.44052906  1.1054362
  0.40243692  0.30179491  0.91216599  1.13082724  0.80231909  0.50341774
  0.23161131  0.35105119  0.32236377  0.82267127  0.91461553  0.56597076
  0.76393095  0.9008792   0.41224933  0.39897747  0.58644554  0.51947146
  1.06847105  0.91237405  0.34874716  0.83959344  0.51669053  0.46544819
  0.56874802  0.6768135   0.76004207  0.63262638  0.62986198  0.84711002
  1.05019422  0.86545662  0.61937746  0.56008459  0.39665551  0.39971446
  0.51636406  0.46776067  0.36391908  0.83299141]

The printed (y_data - (x_data * 0.8)) is (note every value falls in [0.2, 0.4], as expected from the noise term):

[ 0.30236293  0.34878549  0.39782212  0.22748093  0.39731086  0.20940352
  0.37024473  0.33861265  0.38183448  0.33821871  0.22693534  0.20110982
  0.34427332  0.29718936  0.32067133  0.22819227  0.23873822  0.35571967
  0.22666093  0.206933    0.30675593  0.30552573  0.35525566  0.32646104
  0.3942957   0.37391813  0.24534038  0.35621247  0.3899481   0.20784991
  0.35596785  0.31513061  0.30314127  0.22144839  0.32912577  0.27338838
  0.20194944  0.3051247   0.23520239  0.35238039  0.28935769  0.26995916
  0.27165944  0.38933312  0.28630053  0.38942724  0.28122571  0.2017452
  0.31772132  0.28223454  0.38015189  0.33828362  0.38400026  0.22112384
  0.34666705  0.24605172  0.30917718  0.23027535  0.27079898  0.30993316
  0.31026536  0.25190513  0.32276609  0.38793899  0.34627502  0.38078407
  0.20427263  0.22197043  0.21105867  0.34653595  0.3759992   0.28628476
  0.22242205  0.35271164  0.2320551   0.22505493  0.20508665  0.22654249
  0.37158179  0.23088081  0.32931217  0.20040152  0.39161633  0.29018323
  0.3535739   0.38360502  0.30891254  0.35115866  0.34723196  0.33612635
  0.35837963  0.3125256   0.2983706   0.22540979  0.39657205  0.32506789
  0.22597314  0.31521157  0.25360146  0.28835487]

The printed k and b over training are:

0 [0.14778559, 0.26584759]
20 [0.51506805, 0.441048]
40 [0.6482842, 0.37625444]
60 [0.7260837, 0.33841407]
80 [0.77151942, 0.31631491]
100 [0.7980544, 0.30340874]
120 [0.81355113, 0.29587138]
140 [0.82260132, 0.29146951]
160 [0.82788676, 0.28889877]
180 [0.83097351, 0.28739744]
200 [0.83277613, 0.28652063]
220 [0.83382899, 0.28600857]
240 [0.83444381, 0.28570953]
260 [0.83480293, 0.28553486]
280 [0.83501256, 0.2854329]
300 [0.8351351, 0.2853733]
320 [0.83520657, 0.28533852]
340 [0.83524829, 0.28531826]
360 [0.83527273, 0.28530636]
380 [0.83528692, 0.28529945]
400 [0.8352952, 0.2852954]
420 [0.83530003, 0.28529307]
440 [0.83530289, 0.28529167]
460 [0.8353045, 0.2852909]
480 [0.83530569, 0.28529033]
500 [0.83530593, 0.28529021]
520 [0.83530593, 0.28529021]
540 [0.83530593, 0.28529021]
560 [0.83530593, 0.28529021]
580 [0.83530593, 0.28529021]
600 [0.83530593, 0.28529021]
620 [0.83530593, 0.28529021]
640 [0.83530593, 0.28529021]
660 [0.83530593, 0.28529021]
680 [0.83530593, 0.28529021]
700 [0.83530593, 0.28529021]
720 [0.83530593, 0.28529021]
740 [0.83530593, 0.28529021]
760 [0.83530593, 0.28529021]
780 [0.83530593, 0.28529021]
800 [0.83530593, 0.28529021]
820 [0.83530593, 0.28529021]
840 [0.83530593, 0.28529021]
860 [0.83530593, 0.28529021]
880 [0.83530593, 0.28529021]
900 [0.83530593, 0.28529021]
920 [0.83530593, 0.28529021]
940 [0.83530593, 0.28529021]
960 [0.83530593, 0.28529021]
980 [0.83530593, 0.28529021]
1000 [0.83530593, 0.28529021]

Looking at the results, iterating past roughly step 500 is pointless: k and b have already converged to [0.83530593, 0.28529021]. (They land near, but not exactly at, 0.8 and 0.3 because the 100 noisy samples are a finite draw.) A recurring lesson in machine learning is not to waste compute and time on training that has stopped improving.
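
One simple remedy is to stop once the parameters stop changing. A hypothetical early-stopping variant of the loop above (the tolerance 1e-7 is an arbitrary choice, not from the original post):

with tf.Session() as sess:
    sess.run(init)
    prev = None
    for step in range(1001):
        sess.run(train)
        cur = sess.run([k, b])
        # stop once neither k nor b moved by more than the tolerance
        if prev is not None and max(abs(cur[0] - prev[0]), abs(cur[1] - prev[1])) < 1e-7:
            print('converged at step', step, cur)
            break
        prev = cur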
