【tf算子学习】tensor_scatter_add

前言:学习tf.tensor_scatter_add算子,理解算子的基本功能和在多维上的计算。

1.找到算子的定义

不知道这个是不是google的官网,链接地址:https://tensorflow.google.cn/versions/r2.1/api_docs/python/tf/tensor_scatter_nd_add
在这里插入图片描述
tf.tensor_scatter_nd_add和tf.tensor_scatter_add这两个算子的关系,好像tf.tensor_scatter_nd_add是tf.tensor_scatter_add的升级版吧,具体不清楚,我可能是错的。

tf.tensor_scatter_nd_add,tensorflow中文社区链接地址:http://www.tensorfly.cn/tfdoc/api_docs/tf/tensor_scatter_nd_add.html
在这里插入图片描述

W3Cschool链接地址:https://www.w3cschool.cn/tensorflow_python/tensorflow_python-h5yo2j65.html
在这里插入图片描述
在W3Cschool,我只能找到tf.scatter_nd_add算子,这个算子和tf.tensor_scatter_add算子功能基本上是一样的,就是多了use_locking属性。
在这里插入图片描述
tf.scatter_nd_add是tf.scatter_add在多维上的运用。在我理解上多维的话就能进行tensor之间直接的加减乘除,而不再是数字与数字之间的加减运算。

2.理解算子入参之间的shape关系

多维的情况是很复杂的,tf.tensor_scatter_add三个入参之间的shape存在一定的关系,算子的第一个入参input_x可以看作是原始的tensor,indices是将要对input_x中的某一维的tensor做处理的索引tensor,只能是二维的,一维会报错,updates是根据indices将要和input_x中indices位置上的tensor或标量做出计算的新tensor,就是算子的输出tensor的shape要和input_x的shape一样,该算子的功能可以看作是对input_x这个入参进行数据上的修改。

如果算子的三个入参的shape关系没有给对,就会报以下的错误:
在这里插入图片描述
(图片来源:https://stackoverflow.com/questions/55615900/tensorflow-tf-tensor-scatter-add-for-two-tensors-with-unknown-batch-size

从最简单的一维开始,在网上找个例子看看是怎样计算的。
tf.scatter_sub 函数,链接地址:https://blog.csdn.net/weixin_41874599/article/details/82793658
在这里插入图片描述
tf.scatter_nd,链接地址:https://www.w3cschool.cn/tensorflow_python/tensorflow_python-led42j40.html
在这里插入图片描述
上面这两个规则很关键。
(1) indices的shape最后一维的取值应该小于等于input_x的秩,也就是input_x的shape的长度。indices本身要求至少是二维的。
在这里插入图片描述
(2)updates的shape是要根据input_x和indices的shape计算得到的,公式是固定的:updates.shape = indices.shape[:-1]+input_x.shape[indices.shape[-1]:]

在多维上,tf.scatter_nd, tf.scatter_nd_add, tf.scatter_nd_sub, tf.scatter_nd_update对于入参之间shape的运算关系是一样的。

tf.scatter_nd_update,链接地址:https://blog.csdn.net/DaVinciL/article/details/84027241

一维shape
在这里插入图片描述
ref.shape = [8]
indices.shape = [4,1]
按照公式updates.shape = indices.shape[:-1]+input_x.shape[indices.shape[-1]:]计算得到:
updates.shape = [4]

二维shape
在这里插入图片描述
ref.shape = [6,6]
indices.shape = [4,2]
updates.shape = [4]

3.难点是indices数据的生成

省略。。。

4.顺道学习相关算子

重点学习:tensor_scatter_sub/tensor_scatter_update
tf.tensor_scatter_nd_sub/tf.tensor_scatter_nd_update
在这里插入图片描述
在这里插入图片描述
tf.tensor_scatter_update的indices在同一维度的里面的数字是可以包含重复项的。
在这里插入图片描述
等等。。。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Logintern09

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值