【Tensorflow2.x学习笔记】tf.GradientTape自动求梯度

tf.constant()函数

作用: 创建tensor常量
函数形式: tf.constant(value, shape, dtype=None, name=None)
参数释义: value:值,shape:数据形状,dtype:数据类型,name:名称

tf.GradientTape()函数

作用: 用于计算函数梯度,配合with as结构使用
函数形式: 使用__init__函数初始化对象,__enter__函数和__exit__函数
配合使用实现上下文管理器(用于连接需要计算梯度的函数与变量)

with as结构

作用: with as可用于简化try finally代码,与下述形式的try finally等价

try:  
    执行 __enter__的内容  
    执行 with_block.  
finally:  
    执行 __exit__内容 

执行过程: with expression as variable的执行过程是首先执行__enter__函数,它的返回值会赋值给as后面的variable,然后执行with-block中的语句,不论发生什么with-block执行后都执行__exit__函数

tape.watch()函数

作用: 确保某个tensor被tape追踪
函数形式: watch(tensor)
参数释义: tensor: 一个Tensor或者一个Tensor列表

注意:watch函数把需要计算梯度的变量加入。GradientTape默认只监控由 tf.Variable 创建的traiable=True属性(默认)的变量。若变量是constant,则计算梯度 需要增加 tape.watch([a, b, c])函数。当然,也可以设置不自动监控可训练变量,完全由自己指定,设置watch_accessed_variables=False就行了(一般用不到)。

tape.gradient()函数

作用: 根据tape上面的上下文来计算某个或者某些tensor的梯度
函数形式: gradient(target,sources,output_gradients=None,unconnected_gradients=tf.UnconnectedGradients.NONE)
参数释义: target:需要求导的目标函数方程、sources:被求导的一个Tensor或者Tensor列表

代码实践

# -*- coding : utf-8 -*-            
# @Time : 2022/3/4 23:02
# @Author : SXQ
# @FileName : autograd

import tensorflow as tf

# constant函数用于生成tensor常量
x = tf.constant(2.)
a = tf.constant(2.)
b = tf.constant(3.)
c = tf.constant(4.)

# with可用于简化try finally代码
# with expression as variable的执行过程是,首先执行__enter__函数
# 它的返回值会赋值给as后面的variable
# 然后执行with-block中的语句,不论发生什么with-block执行后都执行__exit__函数
with tf.GradientTape() as tape:
    # 确保某个tensor被tape追踪
    tape.watch([a, b, c])
    # 函数公式
    y = a ** 2 * x + b * x + c
# gradient函数根据tape上面的上下文来计算某个或者某些tensor的梯度
[dy_da, dy_db, dy_dc] = tape.gradient(y, [a, b, c])
print(dy_da, dy_db, dy_dc)
  • 1
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值