tensorflow 6. 动态计算图实现线性回归(eager execution)

65 篇文章 5 订阅
34 篇文章 4 订阅

本例程源码来自这里

目前我已把我自己手工敲写加注释的代码放到自己的github账户上面,项目地址在这里:https://github.com/RootYuan/tensorflow_examples_practice/

下面是正文分割线


tensorflow一直以来是基于静态计算图的,这其实跟程序的执行过程并不一致,没办法使用python语言取控制中间流程。PyTorch 的动态图一直是 TensorFlow用户求之不得的功能。tensorflow从v1.5版本加入动态图,目前为止已更新到v1.7版本,动态图得到了进一步完善。

今天主要运行一下动态图版本的线性回归。静态计算图一般是先搭建图结构,然后使用sess.run填入数据并并运行优化器。动态图虽然思路类似,但是不在需要sess和显示grap了,就像调用函数一样方便。

另外,使用动态库开始需要import tensorflow.contrib.eager as tfe并调用tfe.enable_eager_execution()

下面是源码,我主要为了感性认识,还有很多知识点没有弄清楚,后续遇到再继续。

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow.contrib.eager as tfe

# 开启eager API
tfe.enable_eager_execution()

# 训练数据
train_X = [3.3, 4.4, 5.5, 6.71, 6.93, 4.168, 9.779, 6.182, 7.59, 2.167,
           7.042, 10.791, 5.313, 7.997, 5.654, 9.27, 3.1]
train_Y = [1.7, 2.76, 2.09, 3.19, 1.694, 1.573, 3.366, 2.596, 2.53, 1.221,
           2.827, 3.465, 1.65, 2.904, 2.42, 2.94, 1.3]
n_samples = len(train_X)

# 参数
learning_rate = 0.01
display_step = 100
num_steps = 1000

# 权重和偏置
W = tfe.Variable(np.random.randn())
b = tfe.Variable(np.random.randn())

# 线性回归公式函数(Wx + b)
def linear_regression(inputs):
    return inputs * W + b

# 均方误差函数,计算损失
def mean_square_fn(model_fn, inputs, labels):
    return tf.reduce_sum(tf.pow(model_fn(inputs) - labels, 2)) / (2 * n_samples)

# 随机梯度下降法作为优化器
optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
# 计算梯度
grad = tfe.implicit_gradients(mean_square_fn)

# 优化之前,初始化损失函数
print("Initial cost= {:.9f}".format(
    mean_square_fn(linear_regression, train_X, train_Y)),
    "W=", W.numpy(), "b=", b.numpy())

# 训练
for step in range(num_steps):

    optimizer.apply_gradients(grad(linear_regression, train_X, train_Y))

    if (step + 1) % display_step == 0 or step == 0:
        print( "Epoch:", '%04d' % (step + 1), "cost=",
               "{:.9f}".format( mean_square_fn( linear_regression, train_X, train_Y ) ),
               "W=", W.numpy(), "b=", b.numpy() )

# 图表显示
plt.plot(train_X, train_Y, 'ro', label='Original data')
plt.plot(train_X, np.array(W * train_X + b), label='Fitted line')
plt.legend()
plt.show()

终端输出:

Initial cost= 2.973774910 W= 0.8843342 b= -1.2450998
Epoch: 0001 cost= 1.136796951 W= 0.7314757 b= -1.2640716
Epoch: 0100 cost= 0.289310157 W= 0.51290995 b= -1.053521
Epoch: 0200 cost= 0.243507951 W= 0.4830278 b= -0.84167004
Epoch: 0300 cost= 0.207583487 W= 0.4565633 b= -0.6540487
Epoch: 0400 cost= 0.179406464 W= 0.43312556 b= -0.48788548
Epoch: 0500 cost= 0.157306090 W= 0.41236836 b= -0.34072652
Epoch: 0600 cost= 0.139971867 W= 0.3939852 b= -0.21039806
Epoch: 0700 cost= 0.126376018 W= 0.37770453 b= -0.09497542
Epoch: 0800 cost= 0.115712211 W= 0.36328587 b= 0.0072462982
Epoch: 0900 cost= 0.107348159 W= 0.35051632 b= 0.097776845
Epoch: 1000 cost= 0.100787930 W= 0.3392072 b= 0.17795336
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值