前言:学习tf.tensor_scatter_add算子,理解算子的基本功能和在多维上的计算。
1.找到算子的定义
不知道这个是不是google的官网,链接地址:https://tensorflow.google.cn/versions/r2.1/api_docs/python/tf/tensor_scatter_nd_add
tf.tensor_scatter_nd_add和tf.tensor_scatter_add这两个算子的关系,好像tf.tensor_scatter_nd_add是tf.tensor_scatter_add的升级版吧,具体不清楚,我可能是错的。
tf.tensor_scatter_nd_add,tensorflow中文社区链接地址:http://www.tensorfly.cn/tfdoc/api_docs/tf/tensor_scatter_nd_add.html
W3Cschool链接地址:https://www.w3cschool.cn/tensorflow_python/tensorflow_python-h5yo2j65.html
在W3Cschool,我只能找到tf.scatter_nd_add算子,这个算子和tf.tensor_scatter_add算子功能基本上是一样的,就是多了use_locking属性。
tf.scatter_nd_add是tf.scatter_add在多维上的运用。在我理解上多维的话就能进行tensor之间直接的加减乘除,而不再是数字与数字之间的加减运算。
2.理解算子入参之间的shape关系
多维的情况是很复杂的,tf.tensor_scatter_add三个入参之间的shape存在一定的关系,算子的第一个入参input_x可以看作是原始的tensor,indices是将要对input_x中的某一维的tensor做处理的索引tensor,只能是二维的,一维会报错,updates是根据indices将要和input_x中indices位置上的tensor或标量做出计算的新tensor,就是算子的输出tensor的shape要和input_x的shape一样,该算子的功能可以看作是对input_x这个入参进行数据上的修改。
如果算子的三个入参的shape关系没有给对,就会报以下的错误:
(图片来源:https://stackoverflow.com/questions/55615900/tensorflow-tf-tensor-scatter-add-for-two-tensors-with-unknown-batch-size)
从最简单的一维开始,在网上找个例子看看是怎样计算的。
tf.scatter_sub 函数,链接地址:https://blog.csdn.net/weixin_41874599/article/details/82793658
tf.scatter_nd,链接地址:https://www.w3cschool.cn/tensorflow_python/tensorflow_python-led42j40.html
上面这两个规则很关键。
(1) indices的shape最后一维的取值应该小于等于input_x的秩,也就是input_x的shape的长度。indices本身要求至少是二维的。
(2)updates的shape是要根据input_x和indices的shape计算得到的,公式是固定的:updates.shape = indices.shape[:-1]+input_x.shape[indices.shape[-1]:]
在多维上,tf.scatter_nd, tf.scatter_nd_add, tf.scatter_nd_sub, tf.scatter_nd_update对于入参之间shape的运算关系是一样的。
tf.scatter_nd_update,链接地址:https://blog.csdn.net/DaVinciL/article/details/84027241
一维shape
ref.shape = [8]
indices.shape = [4,1]
按照公式updates.shape = indices.shape[:-1]+input_x.shape[indices.shape[-1]:]计算得到:
updates.shape = [4]
二维shape
ref.shape = [6,6]
indices.shape = [4,2]
updates.shape = [4]
3.难点是indices数据的生成
省略。。。
4.顺道学习相关算子
重点学习:tensor_scatter_sub/tensor_scatter_update
tf.tensor_scatter_nd_sub/tf.tensor_scatter_nd_update
tf.tensor_scatter_update的indices在同一维度的里面的数字是可以包含重复项的。
等等。。。