TensorFlow 图模式入门指南:从基础到线性回归实现

TensorFlow 图模式入门指南:从基础到线性回归实现

training-data-analyst Labs and demos for courses for GCP Training (http://cloud.google.com/training). training-data-analyst 项目地址: https://gitcode.com/gh_mirrors/tr/training-data-analyst

概述

TensorFlow 作为当前最流行的机器学习框架之一,提供了两种不同的执行模式:即时执行(Eager Execution)图执行(Graph Execution)。本文将重点介绍图执行模式的工作原理及其实践应用。

两种执行模式对比

即时执行模式

  • 特点:立即执行操作并返回具体值
  • 优势:直观易用,调试方便,样板代码少
  • 适用场景:原型开发阶段

图执行模式

  • 特点:先构建计算图,然后在会话中执行
  • 优势:性能优化潜力大,适合分布式训练
  • 适用场景:生产环境中的性能敏感型应用

图执行模式基础

1. 构建计算图

在图模式下,操作不会立即执行,而是先构建一个符号化的计算图:

a = tf.constant(value = [5, 3, 8], dtype = tf.int32)
b = tf.constant(value = [3, -1, 2], dtype = tf.int32)
c = tf.add(x = a, y = b)
print(c)  # 此时只输出张量信息,不计算具体值

2. 执行计算图

需要通过tf.Session()来执行计算图:

with tf.Session() as sess:
    result = sess.run(fetches = c)
    print(result)  # 输出实际计算结果

参数化计算图

在实际应用中,我们经常需要动态输入数据。TensorFlow提供了placeholder机制来实现这一需求:

a = tf.placeholder(dtype = tf.int32, shape = [None])  
b = tf.placeholder(dtype = tf.int32, shape = [None])
c = tf.add(x = a, y = b)

with tf.Session() as sess:
    result = sess.run(fetches = c, feed_dict = {
        a: [3, 4, 5],
        b: [-1, 2, 3]
    })
    print(result)

线性回归实战

1. 准备数据

我们使用简单的线性关系作为示例: y = 2x + 10

X = tf.constant(value = [1,2,3,4,5,6,7,8,9,10], dtype = tf.float32)
Y = 2 * X + 10

2. 定义模型和损失函数

使用变量来存储模型参数,并定义均方误差(MSE)作为损失函数:

with tf.variable_scope(name_or_scope = "training", reuse = tf.AUTO_REUSE):
    w0 = tf.get_variable(name = "w0", initializer = tf.constant(value = 0.0, dtype = tf.float32))
    w1 = tf.get_variable(name = "w1", initializer = tf.constant(value = 0.0, dtype = tf.float32))
    
Y_hat = w0 * X + w1
loss_mse = tf.reduce_mean(input_tensor = (Y_hat - Y)**2)

3. 设置优化器

使用梯度下降优化器,它会自动处理可训练变量的更新:

LEARNING_RATE = tf.placeholder(dtype = tf.float32, shape = None)
optimizer = tf.train.GradientDescentOptimizer(learning_rate = LEARNING_RATE).minimize(loss = loss_mse)

4. 训练循环

执行训练过程,定期输出损失值:

STEPS = 1000
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())  # 初始化变量
    
    for step in range(STEPS):
        sess.run(fetches = optimizer, feed_dict = {LEARNING_RATE: 0.02})
        
        if step % 100 == 0:
            print("STEP: {} MSE: {}".format(step, sess.run(fetches = loss_mse)))
    
    print("最终参数:")
    print("w0:{}".format(round(float(sess.run(w0)), 4)))
    print("w1:{}".format(round(float(sess.run(w1)), 4)))

关键概念总结

  1. 变量(Variable):用于存储模型参数,可训练且可变
  2. 占位符(Placeholder):用于运行时输入数据
  3. 会话(Session):连接前端Python和后端执行引擎的桥梁
  4. 优化器(Optimizer):自动计算梯度并更新变量

通过本文的学习,您应该已经掌握了TensorFlow图模式的基本使用方法,并能够实现简单的线性回归模型。图模式虽然需要更多的样板代码,但在性能优化和分布式训练方面具有明显优势,是TensorFlow在生产环境中的推荐使用方式。

training-data-analyst Labs and demos for courses for GCP Training (http://cloud.google.com/training). training-data-analyst 项目地址: https://gitcode.com/gh_mirrors/tr/training-data-analyst

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

丁绮倩

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值