(12)tensorflow高级操作函数

高级操作函数

功能函数代码
根据索引号抽样tf.gather(x,index,axis)
根据索引号采集多个样本tf.gather_nd(x,index)
掩码采样tf.boolean_mask(x, mask, axis)
条件取样tf.where(cond,a,b)
刷新张量tf.scatter_nd(indices, updates, shape)
生成二维网格的采样点坐标tf.meshgrid (x,y)

根据索引号收集数据

  • tf.gather ,tf.gather(x,index,axis)
  • index使用列表传入
import tensorflow as tf
x = tf.random.normal([4,3])
print(x)
a = tf.gather(x,[0,2],axis=0)
print(a)

out:
tf.Tensor(
[[ 0.6171281   0.04664925 -1.240589  ]
 [ 0.3884052  -0.13463594 -0.93812966]
 [ 0.21023595  0.02191296 -0.60263956]
 [-0.59707534 -0.9989014   0.93808985]], shape=(4, 3), dtype=float32)
tf.Tensor(
[[ 0.6171281   0.04664925 -1.240589  ]
 [ 0.21023595  0.02191296 -0.60263956]], shape=(2, 3), dtype=float32)
  • tf.gather_nd(x,index)
  • index 多个列表,指明元素位置
import tensorflow as tf
x = tf.random.normal([4,3,5])
print(x)
a = tf.gather_nd(x,[[1,2,4],[2,0,4],[2,2,4]])
print(a)

out:
tf.Tensor(
[[[ 0.48533428 -0.4602297   0.10718699  1.4634823   1.0696448 ]
  [-0.10646328  1.4204148   0.49611762 -0.7021836  -0.78922015]
  [ 0.27958906  0.77579606  1.1943241   0.13299868  1.8589126 ]]

 [[-1.2397368  -0.8667516   1.2485082  -0.36232617 -0.01963608]
  [-0.3774731  -0.9508353   0.17664587  0.8245683  -1.4487312 ]
  [-0.40236002  1.1351906   0.281336   -1.2188344  -0.5182614 ]]

 [[-1.3938109   1.2935148  -0.8586935  -1.6478568  -0.7925164 ]
  [-0.5961709  -0.14946866 -0.34493566  0.70288414 -0.32704452]
  [ 1.1950675  -0.22109848  0.21102418 -0.3180953  -0.4069654 ]]

 [[-2.664851   -0.6875616  -0.75624275  1.4966092  -0.70387363]
  [ 0.19424789 -0.05890714  0.6462316   0.36867452 -1.6294032 ]
  [-0.2791568  -1.4880462   0.28855824  0.18338643  1.1877102 ]]], shape=(4, 3, 5), dtype=float32)
tf.Tensor([-0.5182614 -0.7925164 -0.4069654], shape=(3,), dtype=float32)

掩码方式采样

  • tf.boolean_mask(x, mask, axis)
  • mask为布尔值列表
import tensorflow as tf
x = tf.random.normal([4,3,5])
b = tf.boolean_mask(x,[False,False,True,True],axis=0)
print(b)

out:
tf.Tensor(
[[[ 0.9954934   0.2346662  -1.5152481  -0.23908533  1.0049196 ]
  [-1.9355854   2.198667    0.7091041   0.8390751  -0.539098  ]
  [-0.8521507  -0.5687699   0.5074792   0.7154239   0.14445351]]

 [[-2.0306711  -0.02072303 -1.4181378  -0.01865017 -0.26235464]
  [ 1.1109879  -1.454516   -0.4335605  -1.37627    -0.7934608 ]
  [-0.22541425 -1.0131035  -0.4386002   0.47543052  0.71290684]]], shape=(2, 3, 5), dtype=float32)

条件取样

  • tf.where(cond,a,b)
  • cond与a与b的shape相同
  • cond为布尔值张量
  • True选a对应位置元素,false选b对应位置元素
import tensorflow as tf
a = tf.random.normal([4,4])
b = tf.random.normal([4,4])
x = tf.where([[True,False,False,False],[True,False,True,False],[True,False,True,False],[True,False,True,False]],a,b)
print(a,'\n',b,'\n',x)

out:
tf.Tensor(
[[ 0.528277   -0.934028   -1.82246     0.6171794 ]
 [ 1.3643358  -0.10889477 -0.42628863  1.3473538 ]
 [-0.8530917  -1.1393764   1.1789635  -0.47311786]
 [-0.06923807 -0.7026398  -0.5189239   0.12308616]], shape=(4, 4), dtype=float32) 
 tf.Tensor(
[[-0.11811212 -0.8660801   1.0821908  -0.80681664]
 [ 0.7331498  -0.31200057 -0.7272013  -0.04565765]
 [ 0.05147053 -0.39878348  1.404608   -0.50364506]
 [-0.48536187 -0.80131745 -0.72818786  1.5861374 ]], shape=(4, 4), dtype=float32) 
 tf.Tensor(
[[ 0.528277   -0.8660801   1.0821908  -0.80681664]
 [ 1.3643358  -0.31200057 -0.42628863 -0.04565765]
 [-0.8530917  -0.39878348  1.1789635  -0.50364506]
 [-0.06923807 -0.80131745 -0.5189239   1.5861374 ]], shape=(4, 4), dtype=float32)

刷新张量的部分数据

  • tf.scatter_nd(indices, updates, shape)
  • 白板的形状通过 shape 参数表示,需要刷新的数据索引号通过 indices 表示,新数据为 updates。 根据 indices 给出的索引位置将 updates 中新的数据依次写入白板中,并返回更新后的结果张量。
import tensorflow as tf
indices = tf.constant([[2],[3]])
updata = tf.constant([5,4])
a = tf.scatter_nd(indices,updata,[10])
print(a)

out:
tf.Tensor([0 0 5 4 0 0 0 0 0 0], shape=(10,), dtype=int32)

生成二维网格的采样点坐标

  • tf.meshgrid(x,y)
  • 将x,y的列表元素枚举法配对
import tensorflow as tf
x = tf.linspace(-7.,8,3)
y = tf.linspace(-8.,8,3)
x,y = tf.meshgrid(x,y)
print(x)
print(y)

out:
tf.Tensor(
[[-7.   0.5  8. ]
 [-7.   0.5  8. ]
 [-7.   0.5  8. ]], shape=(3, 3), dtype=float32)
tf.Tensor(
[[-8. -8. -8.]
 [ 0.  0.  0.]
 [ 8.  8.  8.]], shape=(3, 3), dtype=float32)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

小蜗笔记

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值