tf.scatter_update tf.scatter_sub

tf.scatter_update

scatter_update(
ref,
indices,
updates,
use_locking=None,
name=None
)

scatter_sub(
ref,
indices,
updates,
use_locking=None,
name=None
)
在源码,函数的定义的位置在 tensorflow/Python/ops/gen_state_ops.py.
参数介绍:
ref: 原来的tensor;
indices: 原来tensor中要更新的索引值,同样也 tensor; 必须int
updates: 用于替代原来tensor的tensor值,注意,这个tensor和原来的tensor的最低维度要相同。

import tensorflow as tf 
import numpy as np 


with tf.Session() as sess1:


    c = tf.Variable([[1,2,0],[2,3,4]], dtype=tf.float32, name='biases') 
    cc = tf.Variable([[1,2,0],[2,3,4]], dtype=tf.float32, name='biases1') 
    ccc = tf.Variable([0,1], dtype=tf.int32, name='biases2') 

    #对应label的centers-diff[0--]
    centers = tf.scatter_sub(c,ccc,cc)
    #centers = tf.scatter_sub(c,[0,1],cc)  
    #centers = tf.scatter_sub(c,[0,1],[[1,2,0],[2,3,4]])
    #centers = tf.scatter_sub(c,[0,0,0],[[1,2,0],[2,3,4],[1,1,1]])
    #即c[0]-[1,2,0] \ c[0]-[2,3,4]\ c[0]-[1,1,1],updates要减完:indices与updates元素个数相同

    a = tf.Variable(initial_value=[[0, 0, 0, 0],[0, 0, 0, 0]])  
    b = tf.scatter_update(a, [0, 1], [[1, 1, 0, 0], [1, 0, 4, 0]])  
    #b = tf.scatter_update(a, [0, 1,0], [[1, 1, 0, 0], [1, 0, 4, 0],[1, 1, 0, 1]]) 

    init = tf.global_variables_initializer() 
    sess1.run(init)

    print(sess1.run(centers))
    print(sess1.run(b))


[[ 0.  0.  0.]
 [ 0.  0.  0.]]
[[1 1 0 0]
 [1 0 4 0]]


[[-3. -4. -5.]
 [ 2.  3.  4.]]
[[1 1 0 1]
 [1 0 4 0]]

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值