tf.get_variable在训练过程中不更新——tensorflow变量 梯度更新


Answer

针对特定张量计算的loss,能且只能对与其直接相关的张量进行梯度计算与更新。
[着急直接看后续demo]


问题描述

  1. 模型能够跑通
  2. loss只能部分下降

问题分析

  1. 未能下降的loss为基于二范数构建的损失项
  2. 用于做差求二范数的两个张量中,有一个张量从文件加载,数值固定
  3. 显然,存在一直未能进行更新的张量

Demo

# 构建2*2的张量x1与x2
x1 = tf.get_variable(shape=[2, 2],initializer=tf.contrib.layers.xavier_initializer(uniform=False),dtype=tf.float64,name='x1')
x2 = tf.get_variable(shape=[2, 2],initializer=tf.contrib.layers.xavier_initializer(uniform=False),dtype=tf.float64,name='x2')
# 构建x3与x4,其中x3取自x1的第一行,x4取自x2的第一行
x3 = tf.nn.embedding_lookup(x1, 0)
x4 = tf.nn.embedding_lookup(x2, 0)
# 构建目标函数计算的第一项:w
w = tf.get_variable(shape=[2, 2],initializer=tf.contrib.layers.xavier_initializer(uniform=False),dtype=tf.float64,name='w')
# 构建目标函数计算的第二项:ww
ww = tf.placeholder(dtype=tf.float64, shape=[None, 2], name="ww")
# 构建目标函数
y = w * ww
# 针对梯度更新操作进行模拟
op = tf.train.AdamOptimizer(0.001) # 设定学习率
grad = op.compute_gradients(y) # 梯度计算
news = op.apply_gradients(grad) # 张量更新

至此,准备工作已经完成,下面首先展示无法完成梯度更新的demo:

with tf.Session() as sess:
	# 变量初始化
    sess.run(tf.global_variables_initializer())
    # 模拟训练
    for epoch in range(3):
        print(f"Epoch {epoch + 1}")
        # 因为ww的计算需要x3与x4作为输入,先执行得到x3与x4
        tx3, tx4 = sess.run([x3, x4])
        # 此处tx3与tx4均为arr,不再是Tensor,拼接为ww所需shape
        txx = np.vstack((tx3, tx4))
        # 输入txx,计算ww并完成梯度更新
        ttx, _ = sess.run([ww, news], feed_dict={
            ww: txx
        })
        # 输出x1与x2,查看梯度更新结果
        cx1, cx2 = sess.run([x1, x2])
        print(cx1[0][1], cx2[0][1], cx1[1][1], cx2[1][1])
        print()

实验结果如下:

Epoch 1
-0.3160230179718143 -0.9542103563632182 0.07167841193652867 -0.638824844460277

Epoch 2
-0.3160230179718143 -0.9542103563632182 0.07167841193652867 -0.638824844460277

Epoch 3
-0.3160230179718143 -0.9542103563632182 0.07167841193652867 -0.638824844460277

下面展示能够完成梯度更新的demo:

# 首先将ww改写为
ww = x3 + x4
# 下面与上述Demo基本相同
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(3):
        print(f"Epoch {epoch + 1}")
        _ = sess.run(news)
        cx1, cx2 = sess.run([x1, x2])
        print(cx1[0][1], cx2[0][1], cx1[1][1], cx2[1][1])
        print()

实验结果如下:

Epoch 1
-0.6685521525329547 -0.48460736297654045 -0.15471450503649956 0.3457514012500578

Epoch 2
-0.6695524480668291 -0.4856076585104148 -0.15471450503649956 0.3457514012500578

Epoch 3
-0.6705532353881327 -0.48660844583171836 -0.15471450503649956 0.3457514012500578

不难看出,通过x3更新了x1中的第一行,而第二行没有更新。


  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值