import tensorflow as tf
x = tf.random.normal([4,3])print(x)
a = tf.gather(x,[0,2],axis=0)print(a)
out:
tf.Tensor([[0.61712810.04664925-1.240589][0.3884052-0.13463594-0.93812966][0.210235950.02191296-0.60263956][-0.59707534-0.99890140.93808985]], shape=(4,3), dtype=float32)
tf.Tensor([[0.61712810.04664925-1.240589][0.210235950.02191296-0.60263956]], shape=(2,3), dtype=float32)
tf.gather_nd(x,index)
index 多个列表,指明元素位置
import tensorflow as tf
x = tf.random.normal([4,3,5])print(x)
a = tf.gather_nd(x,[[1,2,4],[2,0,4],[2,2,4]])print(a)
out:
tf.Tensor([[[0.48533428-0.46022970.107186991.46348231.0696448][-0.106463281.42041480.49611762-0.7021836-0.78922015][0.279589060.775796061.19432410.132998681.8589126]][[-1.2397368-0.86675161.2485082-0.36232617-0.01963608][-0.3774731-0.95083530.176645870.8245683-1.4487312][-0.402360021.13519060.281336-1.2188344-0.5182614]][[-1.39381091.2935148-0.8586935-1.6478568-0.7925164][-0.5961709-0.14946866-0.344935660.70288414-0.32704452][1.1950675-0.221098480.21102418-0.3180953-0.4069654]][[-2.664851-0.6875616-0.756242751.4966092-0.70387363][0.19424789-0.058907140.64623160.36867452-1.6294032][-0.2791568-1.48804620.288558240.183386431.1877102]]], shape=(4,3,5), dtype=float32)
tf.Tensor([-0.5182614-0.7925164-0.4069654], shape=(3,), dtype=float32)
掩码方式采样
tf.boolean_mask(x, mask, axis)
mask为布尔值列表
import tensorflow as tf
x = tf.random.normal([4,3,5])
b = tf.boolean_mask(x,[False,False,True,True],axis=0)print(b)
out:
tf.Tensor([[[0.99549340.2346662-1.5152481-0.239085331.0049196][-1.93558542.1986670.70910410.8390751-0.539098][-0.8521507-0.56876990.50747920.71542390.14445351]][[-2.0306711-0.02072303-1.4181378-0.01865017-0.26235464][1.1109879-1.454516-0.4335605-1.37627-0.7934608][-0.22541425-1.0131035-0.43860020.475430520.71290684]]], shape=(2,3,5), dtype=float32)
条件取样
tf.where(cond,a,b)
cond与a与b的shape相同
cond为布尔值张量
True选a对应位置元素,false选b对应位置元素
import tensorflow as tf
a = tf.random.normal([4,4])
b = tf.random.normal([4,4])
x = tf.where([[True,False,False,False],[True,False,True,False],[True,False,True,False],[True,False,True,False]],a,b)print(a,'\n',b,'\n',x)
out:
tf.Tensor([[0.528277-0.934028-1.822460.6171794][1.3643358-0.10889477-0.426288631.3473538][-0.8530917-1.13937641.1789635-0.47311786][-0.06923807-0.7026398-0.51892390.12308616]], shape=(4,4), dtype=float32)
tf.Tensor([[-0.11811212-0.86608011.0821908-0.80681664][0.7331498-0.31200057-0.7272013-0.04565765][0.05147053-0.398783481.404608-0.50364506][-0.48536187-0.80131745-0.728187861.5861374]], shape=(4,4), dtype=float32)
tf.Tensor([[0.528277-0.86608011.0821908-0.80681664][1.3643358-0.31200057-0.42628863-0.04565765][-0.8530917-0.398783481.1789635-0.50364506][-0.06923807-0.80131745-0.51892391.5861374]], shape=(4,4), dtype=float32)
刷新张量的部分数据
tf.scatter_nd(indices, updates, shape)
白板的形状通过 shape 参数表示,需要刷新的数据索引号通过 indices 表示,新数据为 updates。 根据 indices 给出的索引位置将 updates 中新的数据依次写入白板中,并返回更新后的结果张量。
import tensorflow as tf
indices = tf.constant([[2],[3]])
updata = tf.constant([5,4])
a = tf.scatter_nd(indices,updata,[10])print(a)
out:
tf.Tensor([0054000000], shape=(10,), dtype=int32)
生成二维网格的采样点坐标
tf.meshgrid(x,y)
将x,y的列表元素枚举法配对
import tensorflow as tf
x = tf.linspace(-7.,8,3)
y = tf.linspace(-8.,8,3)
x,y = tf.meshgrid(x,y)print(x)print(y)
out:
tf.Tensor([[-7.0.58.][-7.0.58.][-7.0.58.]], shape=(3,3), dtype=float32)
tf.Tensor([[-8.-8.-8.][0.0.0.][8.8.8.]], shape=(3,3), dtype=float32)