tf.scatter_update
scatter_update(
ref,
indices,
updates,
use_locking=None,
name=None
)
scatter_sub(
ref,
indices,
updates,
use_locking=None,
name=None
)
在源码,函数的定义的位置在 tensorflow/Python/ops/gen_state_ops.py.
参数介绍:
ref: 原来的tensor;
indices: 原来tensor中要更新的索引值,同样也 tensor; 必须int
updates: 用于替代原来tensor的tensor值,注意,这个tensor和原来的tensor的最低维度要相同。
import tensorflow as tf
import numpy as np
with tf.Session() as sess1:
c = tf.Variable([[1,2,0],[2,3,4]], dtype=tf.float32, name='biases')
cc = tf.Variable([[1,2,0],[2,3,4]], dtype=tf.float32, name='biases1')
ccc = tf.Variable([0,1], dtype=tf.int32, name='biases2')
#对应label的centers-diff[0--]
centers = tf.scatter_sub(c,ccc,cc)
#centers = tf.scatter_sub(c,[0,1],cc)
#centers = tf.scatter_sub(c,[0,1],[[1,2,0],[2,3,4]])
#centers = tf.scatter_sub(c,[0,0,0],[[1,2,0],[2,3,4],[1,1,1]])
#即c[0]-[1,2,0] \ c[0]-[2,3,4]\ c[0]-[1,1,1],updates要减完:indices与updates元素个数相同
a = tf.Variable(initial_value=[[0, 0, 0, 0],[0, 0, 0, 0]])
b = tf.scatter_update(a, [0, 1], [[1, 1, 0, 0], [1, 0, 4, 0]])
#b = tf.scatter_update(a, [0, 1,0], [[1, 1, 0, 0], [1, 0, 4, 0],[1, 1, 0, 1]])
init = tf.global_variables_initializer()
sess1.run(init)
print(sess1.run(centers))
print(sess1.run(b))
[[ 0. 0. 0.]
[ 0. 0. 0.]]
[[1 1 0 0]
[1 0 4 0]]
[[-3. -4. -5.]
[ 2. 3. 4.]]
[[1 1 0 1]
[1 0 4 0]]