Answer
针对特定张量计算的loss,能且只能对与其直接相关的张量进行梯度计算与更新。
[着急直接看后续demo]
问题描述
- 模型能够跑通
- loss只能部分下降
问题分析
- 未能下降的loss为基于二范数构建的损失项
- 用于做差求二范数的两个张量中,有一个张量从文件加载,数值固定
- 显然,存在一直未能进行更新的张量
Demo
# 构建2*2的张量x1与x2
x1 = tf.get_variable(shape=[2, 2],initializer=tf.contrib.layers.xavier_initializer(uniform=False),dtype=tf.float64,name='x1')
x2 = tf.get_variable(shape=[2, 2],initializer=tf.contrib.layers.xavier_initializer(uniform=False),dtype=tf.float64,name='x2')
# 构建x3与x4,其中x3取自x1的第一行,x4取自x2的第一行
x3 = tf.nn.embedding_lookup(x1, 0)
x4 = tf.nn.embedding_lookup(x2, 0)
# 构建目标函数计算的第一项:w
w = tf.get_variable(shape=[2, 2],initializer=tf.contrib.layers.xavier_initializer(uniform=False),dtype=tf.float64,name='w')
# 构建目标函数计算的第二项:ww
ww = tf.placeholder(dtype=tf.float64, shape=[None, 2], name="ww")
# 构建目标函数
y = w * ww
# 针对梯度更新操作进行模拟
op = tf.train.AdamOptimizer(0.001) # 设定学习率
grad = op.compute_gradients(y) # 梯度计算
news = op.apply_gradients(grad) # 张量更新
至此,准备工作已经完成,下面首先展示无法完成梯度更新的demo:
with tf.Session() as sess:
# 变量初始化
sess.run(tf.global_variables_initializer())
# 模拟训练
for epoch in range(3):
print(f"Epoch {epoch + 1}")
# 因为ww的计算需要x3与x4作为输入,先执行得到x3与x4
tx3, tx4 = sess.run([x3, x4])
# 此处tx3与tx4均为arr,不再是Tensor,拼接为ww所需shape
txx = np.vstack((tx3, tx4))
# 输入txx,计算ww并完成梯度更新
ttx, _ = sess.run([ww, news], feed_dict={
ww: txx
})
# 输出x1与x2,查看梯度更新结果
cx1, cx2 = sess.run([x1, x2])
print(cx1[0][1], cx2[0][1], cx1[1][1], cx2[1][1])
print()
实验结果如下:
Epoch 1
-0.3160230179718143 -0.9542103563632182 0.07167841193652867 -0.638824844460277
Epoch 2
-0.3160230179718143 -0.9542103563632182 0.07167841193652867 -0.638824844460277
Epoch 3
-0.3160230179718143 -0.9542103563632182 0.07167841193652867 -0.638824844460277
下面展示能够完成梯度更新的demo:
# 首先将ww改写为
ww = x3 + x4
# 下面与上述Demo基本相同
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for epoch in range(3):
print(f"Epoch {epoch + 1}")
_ = sess.run(news)
cx1, cx2 = sess.run([x1, x2])
print(cx1[0][1], cx2[0][1], cx1[1][1], cx2[1][1])
print()
实验结果如下:
Epoch 1
-0.6685521525329547 -0.48460736297654045 -0.15471450503649956 0.3457514012500578
Epoch 2
-0.6695524480668291 -0.4856076585104148 -0.15471450503649956 0.3457514012500578
Epoch 3
-0.6705532353881327 -0.48660844583171836 -0.15471450503649956 0.3457514012500578
不难看出,通过x3更新了x1中的第一行,而第二行没有更新。