TensorFlow2.x 更新变量部分元素

背景

复现论文时,有的网络会有这种需求:有一个全局变量B,在网络训练的每个迭代的开始会用到它,每次迭代结束会更新它。

在训练的具体过程因为训练样本比较多,所以每次只能取batch_size个样本去训练,这样我就需要对B的特定位置进行更新。

代码

这里以一个简单的例子,介绍怎样对全局变量B更新部分元素

# -*- coding: gbk -*-
import tensorflow as tf


def main():
    m = 10
    n = 5
    # 设置全局变量
    B = tf.Variable(tf.random.normal([m, n]), dtype=tf.float32, name='B')

    # 假设在函数中需要更新全局变量B
    def update_B(index, new_value):
        B.assign(tf.tensor_scatter_nd_update(B, indices=tf.expand_dims(index, axis=1), updates=new_value))
        # # 因为index是一维列表,无法和B的维度进行匹配,所以要扩展一个维度

    # 设置索引 想对【1, 2】行进行更新
    index = tf.constant(list(range(1, 3)), dtype=tf.int32)
    # 新值
    new_B = tf.constant([[21., 22., 23., 24., 25.],
                         [31., 32., 33., 34., 35.]], dtype=tf.float32)
	# 下面显示B更新前后的前5行
    print('B:', B.numpy()[:5, :])
    update_B(index, new_B)
    print('B has been updated')
    print('B:', B.numpy()[:5, :])


main()

结果

在这里插入图片描述

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值