背景
复现论文时,有的网络会有这种需求:有一个全局变量B,在网络训练的每个迭代的开始会用到它,每次迭代结束会更新它。
在训练的具体过程因为训练样本比较多,所以每次只能取batch_size个样本去训练,这样我就需要对B的特定位置进行更新。
代码
这里以一个简单的例子,介绍怎样对全局变量B更新部分元素
# -*- coding: gbk -*-
import tensorflow as tf
def main():
m = 10
n = 5
# 设置全局变量
B = tf.Variable(tf.random.normal([m, n]), dtype=tf.float32, name='B')
# 假设在函数中需要更新全局变量B
def update_B(index, new_value):
B.assign(tf.tensor_scatter_nd_update(B, indices=tf.expand_dims(index, axis=1), updates=new_value))
# # 因为index是一维列表,无法和B的维度进行匹配,所以要扩展一个维度
# 设置索引 想对【1, 2】行进行更新
index = tf.constant(list(range(1, 3)), dtype=tf.int32)
# 新值
new_B = tf.constant([[21., 22., 23., 24., 25.],
[31., 32., 33., 34., 35.]], dtype=tf.float32)
# 下面显示B更新前后的前5行
print('B:', B.numpy()[:5, :])
update_B(index, new_B)
print('B has been updated')
print('B:', B.numpy()[:5, :])
main()