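First, here is what we are going to do: generate data from the model y = k * x + b, then let TensorFlow estimate the values of k and b from that data.
Use numpy to generate 100 random points, uniformly distributed in [0, 1): x_data = np.random.rand(100)
Use numpy to add random noise to each point: y_data = x_data * 0.8 + (0.2 + 0.2*np.random.rand(100)). The noise term 0.2 + 0.2*np.random.rand(100) is uniform on [0.2, 0.4), so the data scatter around a line with slope 0.8 and an average offset of 0.3.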
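To make the data-generation step concrete without TensorFlow, here is a minimal pure-Python sketch of the same process. It assumes Python's standard `random` module in place of `np.random.rand`, and it uses a fixed seed for reproducibility. Both are substitutions for illustration only, since the original code is unseeded. The sketch checks the claim above that every offset y - 0.8x falls in [0.2, 0.4) and averages about 0.3.

```python
import random

def make_data(n=100, seed=0):
    """Pure-Python stand-in for the numpy lines above (assumption: stdlib random, fixed seed).
    x ~ U[0, 1); y = 0.8 * x + noise, with noise ~ U[0.2, 0.4)."""
    rng = random.Random(seed)  # fixed seed so the run is reproducible
    x_data = [rng.random() for _ in range(n)]
    y_data = [x * 0.8 + (0.2 + 0.2 * rng.random()) for x in x_data]
    return x_data, y_data

x_data, y_data = make_data()
# Recover the noise that was added to each point
offsets = [y - 0.8 * x for x, y in zip(x_data, y_data)]
print(min(offsets), max(offsets))   # both lie inside [0.2, 0.4)
print(sum(offsets) / len(offsets))  # close to 0.3, the mean of U[0.2, 0.4)
```

Because the offset averages 0.3 rather than 0, the intercept b that the model should find is roughly 0.3, not 0.2.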
Build the linear model, starting both parameters at 0:
```python
b = tf.Variable(0.)
k = tf.Variable(0.)
y = k * x_data + b
```
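Build a loss function using the quadratic cost (mean squared error): loss = tf.reduce_mean(tf.square(y_data - y)). The underlying formula is:
$$\text{loss} = \frac{1}{n}\sum_{i=1}^{n}\left(y_i - (k x_i + b)\right)^2$$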
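As a sanity check on what `tf.reduce_mean(tf.square(y_data - y))` computes, here is a small pure-Python version of the same mean squared error. The function name `mse_loss` is hypothetical, introduced only for this illustration, and the three sample points are chosen by hand so the result can be verified on paper.

```python
def mse_loss(k, b, x_data, y_data):
    """Mean squared error of the line y = k * x + b, the same quantity as
    tf.reduce_mean(tf.square(y_data - y)) in the TensorFlow code."""
    residuals = [y - (k * x + b) for x, y in zip(x_data, y_data)]
    return sum(r * r for r in residuals) / len(residuals)

# Hand-checkable example: three points lying exactly on y = 2x + 1
xs = [0.0, 1.0, 2.0]
ys = [1.0, 3.0, 5.0]
print(mse_loss(0.0, 0.0, xs, ys))  # (1 + 9 + 25) / 3 = 11.666...
print(mse_loss(2.0, 1.0, xs, ys))  # 0.0, the exact line has zero loss
```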
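Taking the partial derivatives of the loss with respect to k and b and setting them to zero locates the critical point. Because this loss is a convex quadratic in k and b, that critical point is the minimum.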
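For this particular loss, the critical point can be written in closed form. The partial derivatives are ∂loss/∂k = -(2/n) Σ x_i (y_i - k x_i - b) and ∂loss/∂b = -(2/n) Σ (y_i - k x_i - b). Setting both to zero gives k = Σ(x_i - x̄)(y_i - ȳ) / Σ(x_i - x̄)² and b = ȳ - k x̄, which is ordinary least squares. The sketch below computes that point directly. It is an illustration of the calculus, not part of the TensorFlow program, and `least_squares` is a hypothetical helper name.

```python
def least_squares(x_data, y_data):
    """Solve d(loss)/dk = 0 and d(loss)/db = 0 for the MSE loss in closed form."""
    n = len(x_data)
    x_mean = sum(x_data) / n
    y_mean = sum(y_data) / n
    sxy = sum((x - x_mean) * (y - y_mean) for x, y in zip(x_data, y_data))
    sxx = sum((x - x_mean) ** 2 for x in x_data)
    k = sxy / sxx
    b = y_mean - k * x_mean
    return k, b

print(least_squares([0.0, 1.0, 2.0], [1.0, 3.0, 5.0]))  # (2.0, 1.0)
```

Applied to the tutorial's own x_data and y_data, this closed form should land on the same point that gradient descent converges to. Gradient descent is still worth learning, because most models have no such closed-form solution.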
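Define the optimizer used for training. Here we choose gradient descent with a learning rate of 0.2: optimizer = tf.train.GradientDescentOptimizer(0.2)
If gradient descent is unfamiliar, think of it by analogy with taking derivatives. Differentiating a quadratic function tells you which direction it decreases in and where it bottoms out. Gradient descent does the same thing step by step: it repeatedly moves each parameter a little in the direction that lowers the loss.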
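To make the quadratic analogy concrete, here is a one-variable gradient descent on f(w) = (w - 3)², using the same learning rate of 0.2 as the tutorial. Each step applies w ← w - 0.2 · f'(w), which is the rule the optimizer applies to k and b. The function and its minimum at 3 are made up for illustration.

```python
def gradient_descent_1d(grad, w0, learning_rate, steps):
    """Repeatedly step against the derivative: w <- w - learning_rate * grad(w)."""
    w = w0
    for _ in range(steps):
        w = w - learning_rate * grad(w)
    return w

# f(w) = (w - 3)^2 has derivative f'(w) = 2 * (w - 3) and its minimum at w = 3.
w = gradient_descent_1d(lambda w: 2 * (w - 3), w0=0.0, learning_rate=0.2, steps=20)
print(w)  # approaches 3; each step shrinks the distance to 3 by a factor of 0.6
```

With learning rate 0.2, one step maps the error e = w - 3 to e - 0.4e = 0.6e. That is why the iterate closes in on the minimum geometrically instead of jumping straight to it.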
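Minimize the cost function: train = optimizer.minimize(loss). Each time train is run, k and b take one small step that reduces the loss. If this is hard to picture, compare it to the amplitude of a damped spring oscillator in physics, which gradually shrinks over time.
Full code (x_data and y_data are printed to make the results easier to inspect):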
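The minimize step can be reproduced by hand. Below is a pure-Python sketch of the training loop, assuming the same update rule that `GradientDescentOptimizer` applies (each variable minus learning rate times its gradient). It uses the analytic gradients of the MSE loss instead of TensorFlow's automatic differentiation, and it generates data with the standard `random` module and a fixed seed, which is an assumption made for reproducibility. `train_linear` is a hypothetical name.

```python
import random

def train_linear(x_data, y_data, learning_rate=0.2, steps=1001):
    """Plain gradient descent on k and b for the MSE loss.
    Both gradients are computed from the current (k, b) before either is updated."""
    n = len(x_data)
    k, b = 0.0, 0.0  # same starting values as tf.Variable(0.)
    for _ in range(steps):
        residuals = [y - (k * x + b) for x, y in zip(x_data, y_data)]
        grad_k = -2.0 / n * sum(x * r for x, r in zip(x_data, residuals))
        grad_b = -2.0 / n * sum(residuals)
        k -= learning_rate * grad_k
        b -= learning_rate * grad_b
    return k, b

rng = random.Random(0)  # hypothetical fixed seed, for reproducibility
x_data = [rng.random() for _ in range(100)]
y_data = [x * 0.8 + (0.2 + 0.2 * rng.random()) for x in x_data]
k, b = train_linear(x_data, y_data)
print(k, b)  # near 0.8 and 0.3, but not exactly: the fit is to this noisy sample
```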
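```python
# Written for TensorFlow 1.x. Under TensorFlow 2, replace the first import with:
#   import tensorflow.compat.v1 as tf
#   tf.disable_v2_behavior()
import tensorflow as tf
import numpy as np

# Use numpy to generate 100 random points in [0, 1)
x_data = np.random.rand(100)
print(x_data)
# Add random noise in [0.2, 0.4) on top of the line y = 0.8 * x
y_data = x_data * 0.8 + (0.2 + 0.2*np.random.rand(100))
print(y_data)
print(y_data - (x_data * 0.8))

# Build the linear model
b = tf.Variable(0.)
k = tf.Variable(0.)
y = k * x_data + b

# Build the loss function: quadratic cost (mean squared error)
loss = tf.reduce_mean(tf.square(y_data - y))
# Define the optimizer used for training: gradient descent, learning rate 0.2
optimizer = tf.train.GradientDescentOptimizer(0.2)
# Minimize the cost function
train = optimizer.minimize(loss)

# Initialize the variables
init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    for step in range(1001):
        sess.run(train)
        # Print k and b every 20 steps
        if step % 20 == 0:
            print(step, sess.run([k, b]))
```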
The printed x_data (first array) and y_data (second array) are:
[ 7.44110800e-01 5.43004562e-01 1.40464610e-01 1.81464847e-01 6.39729239e-02 8.55839985e-01 8.57288253e-01 5.97521668e-01 2.78126750e-01 6.33273527e-01 7.36156567e-01 1.04051817e-01 8.88657006e-01 6.59313727e-01 5.90586641e-01 2.20599954e-01 4.35010023e-02 5.94651847e-01 1.34688295e-03 7.61364799e-01 2.40110638e-01 4.30085960e-01 9.29408073e-01 2.83296015e-01 9.66464568e-01 5.64793967e-01 2.89583567e-02 4.78895281e-01 6.55705166e-01 9.39931048e-01 9.22667199e-02 4.17453020e-01 8.18261854e-01 7.51253933e-01 8.34488361e-01 5.47085942e-01 2.20812245e-01 3.64683210e-01 2.72265898e-01 6.99546767e-01 5.17783069e-01 6.86348141e-01 8.26539294e-02 7.21142415e-01 5.25883873e-01 7.95120312e-01 5.53413821e-02 7.85643413e-01 6.00917521e-01 1.00047819e-01 2.25730457e-01 3.51933714e-01 3.02575025e-01 2.47871839e-02 7.68447228e-01 5.07069502e-01 2.80167969e-01 3.69932599e-01 2.12162593e-01 9.94378802e-01 1.15214455e-01 6.23622264e-02 7.36749874e-01 9.28610319e-01 5.70055090e-01 1.53292088e-01 3.41733503e-02 1.61350956e-01 1.39131378e-01 5.95169150e-01 6.73270412e-01 3.49607503e-01 6.76886118e-01 6.85209452e-01 2.25242789e-01 2.17403176e-01 4.76698614e-01 3.66161202e-01 8.71111576e-01 8.51866555e-01 2.42937403e-02 7.98989890e-01 1.56342752e-01 2.19081202e-01 2.68967638e-01 3.66510600e-01 5.63911914e-01 3.51834652e-01 3.53287522e-01 6.38729594e-01 8.64768242e-01 6.91163778e-01 4.01258574e-01 4.18343511e-01 1.04330897e-04 9.33082029e-02 3.62988647e-01 1.90686375e-01 1.37897024e-01 6.80795678e-01]
[ 0.89765157 0.78318914 0.51019381 0.37265281 0.4484892 0.89407551 1.05607533 0.81662998 0.60433588 0.84483753 0.81586059 0.28435128 1.05519893 0.82464034 0.79314064 0.40467223 0.27353902 0.83144114 0.22773844 0.81602484 0.49884444 0.6495945 1.09878211 0.55309785 1.16746735 0.82575331 0.26850707 0.73932869 0.91451223 0.95979475 0.42978122 0.64909303 0.95775075 0.82245154 0.99671646 0.71105713 0.37859923 0.59687127 0.45301511 0.91201781 0.70358415 0.81903768 0.33778258 0.96624705 0.70700762 1.02552349 0.32549882 0.83025993 0.79845534 0.36227279 0.56073625 0.61983059 0.62606028 0.24095358 0.96142484 0.65170733 0.53331155 0.52622143 0.44052906 1.1054362 0.40243692 0.30179491 0.91216599 1.13082724 0.80231909 0.50341774 0.23161131 0.35105119 0.32236377 0.82267127 0.91461553 0.56597076 0.76393095 0.9008792 0.41224933 0.39897747 0.58644554 0.51947146 1.06847105 0.91237405 0.34874716 0.83959344 0.51669053 0.46544819 0.56874802 0.6768135 0.76004207 0.63262638 0.62986198 0.84711002 1.05019422 0.86545662 0.61937746 0.56008459 0.39665551 0.39971446 0.51636406 0.46776067 0.36391908 0.83299141]
The printed y_data - (x_data * 0.8), that is, the noise added to each point (every value lies between 0.2 and 0.4):
[ 0.30236293 0.34878549 0.39782212 0.22748093 0.39731086 0.20940352 0.37024473 0.33861265 0.38183448 0.33821871 0.22693534 0.20110982 0.34427332 0.29718936 0.32067133 0.22819227 0.23873822 0.35571967 0.22666093 0.206933 0.30675593 0.30552573 0.35525566 0.32646104 0.3942957 0.37391813 0.24534038 0.35621247 0.3899481 0.20784991 0.35596785 0.31513061 0.30314127 0.22144839 0.32912577 0.27338838 0.20194944 0.3051247 0.23520239 0.35238039 0.28935769 0.26995916 0.27165944 0.38933312 0.28630053 0.38942724 0.28122571 0.2017452 0.31772132 0.28223454 0.38015189 0.33828362 0.38400026 0.22112384 0.34666705 0.24605172 0.30917718 0.23027535 0.27079898 0.30993316 0.31026536 0.25190513 0.32276609 0.38793899 0.34627502 0.38078407 0.20427263 0.22197043 0.21105867 0.34653595 0.3759992 0.28628476 0.22242205 0.35271164 0.2320551 0.22505493 0.20508665 0.22654249 0.37158179 0.23088081 0.32931217 0.20040152 0.39161633 0.29018323 0.3535739 0.38360502 0.30891254 0.35115866 0.34723196 0.33612635 0.35837963 0.3125256 0.2983706 0.22540979 0.39657205 0.32506789 0.22597314 0.31521157 0.25360146 0.28835487]
The values of k and b during training, printed every 20 steps as step [k, b]:
0 [0.14778559, 0.26584759]
20 [0.51506805, 0.441048]
40 [0.6482842, 0.37625444]
60 [0.7260837, 0.33841407]
80 [0.77151942, 0.31631491]
100 [0.7980544, 0.30340874]
120 [0.81355113, 0.29587138]
140 [0.82260132, 0.29146951]
160 [0.82788676, 0.28889877]
180 [0.83097351, 0.28739744]
200 [0.83277613, 0.28652063]
220 [0.83382899, 0.28600857]
240 [0.83444381, 0.28570953]
260 [0.83480293, 0.28553486]
280 [0.83501256, 0.2854329]
300 [0.8351351, 0.2853733]
320 [0.83520657, 0.28533852]
340 [0.83524829, 0.28531826]
360 [0.83527273, 0.28530636]
380 [0.83528692, 0.28529945]
400 [0.8352952, 0.2852954]
420 [0.83530003, 0.28529307]
440 [0.83530289, 0.28529167]
460 [0.8353045, 0.2852909]
480 [0.83530569, 0.28529033]
500 [0.83530593, 0.28529021]
520 [0.83530593, 0.28529021]
540 [0.83530593, 0.28529021]
560 [0.83530593, 0.28529021]
580 [0.83530593, 0.28529021]
600 [0.83530593, 0.28529021]
620 [0.83530593, 0.28529021]
640 [0.83530593, 0.28529021]
660 [0.83530593, 0.28529021]
680 [0.83530593, 0.28529021]
700 [0.83530593, 0.28529021]
720 [0.83530593, 0.28529021]
740 [0.83530593, 0.28529021]
760 [0.83530593, 0.28529021]
780 [0.83530593, 0.28529021]
800 [0.83530593, 0.28529021]
820 [0.83530593, 0.28529021]
840 [0.83530593, 0.28529021]
860 [0.83530593, 0.28529021]
880 [0.83530593, 0.28529021]
900 [0.83530593, 0.28529021]
920 [0.83530593, 0.28529021]
940 [0.83530593, 0.28529021]
960 [0.83530593, 0.28529021]
980 [0.83530593, 0.28529021]
1000 [0.83530593, 0.28529021]
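The output shows that k and b stop changing, at the printed precision, by step 500, so every iteration after that is wasted. The fit settles at k ≈ 0.835 and b ≈ 0.285 rather than exactly 0.8 and 0.3 because it is the best line for this particular noisy sample of 100 points. A key concern in machine learning is not to waste computing power and time for nothing.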
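One simple way to avoid those wasted iterations is to stop as soon as an update no longer moves the parameters meaningfully, instead of always running a fixed 1001 steps. The pure-Python sketch below applies this rule to the same gradient descent. The tolerance `tol=1e-7`, the cap `max_steps`, and the function name `train_until_converged` are hypothetical choices for illustration. A TensorFlow version would apply the same check inside the session loop.

```python
def train_until_converged(x_data, y_data, learning_rate=0.2, tol=1e-7, max_steps=100000):
    """Gradient descent on k and b that stops once an update changes neither
    parameter by more than tol (assumed threshold), or after max_steps."""
    n = len(x_data)
    k, b = 0.0, 0.0
    for step in range(1, max_steps + 1):
        residuals = [y - (k * x + b) for x, y in zip(x_data, y_data)]
        grad_k = -2.0 / n * sum(x * r for x, r in zip(x_data, residuals))
        grad_b = -2.0 / n * sum(residuals)
        dk, db = learning_rate * grad_k, learning_rate * grad_b
        k -= dk
        b -= db
        # Stop when the parameters have effectively stopped moving
        if max(abs(dk), abs(db)) < tol:
            return k, b, step
    return k, b, max_steps

k, b, steps = train_until_converged([0.0, 1.0, 2.0], [1.0, 3.0, 5.0])
print(k, b, steps)  # close to (2, 1), after far fewer than max_steps iterations
```

A small update does not by itself guarantee that the parameters are exactly at the minimum. For a well-conditioned problem like this one, however, the remaining error is only a modest multiple of the last step size.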