【Tensorflow】gather与gather_nd

tf1.x 类比切片

目录

tf.one_hot 提取行

gather 提取行 

gather 提取列

gather_nd


tf.one_hot 提取行

与lookup的作用差不多embedding层_tensorflow中的Embedding操作详解_weixin_39835321的博客-CSDN博客

import tensorflow as tf
import numpy as np

embedding1 = tf.constant(
    [
        [0.21,0.41,0.51,0.11],
        [0.22,0.42,0.52,0.12],
        [0.23,0.43,0.53,0.13],
        [0.24,0.44,0.54,0.14]
    ],dtype=tf.float32)
 
feature_batch = tf.constant([2,3,1,0])
# feature_batch
# <tf.Tensor 'Const_1:0' shape=(4,) dtype=int32>


feature_batch_one_hot = tf.one_hot(feature_batch, depth=4)
# feature_batch_one_hot
# <tf.Tensor 'one_hot:0' shape=(4, 4) dtype=float32>

get_embedding2 = tf.matmul(feature_batch_one_hot, embedding1)
# get_embedding2
# <tf.Tensor 'MatMul_2:0' shape=(4, 4) dtype=float32>

运行时,上面的遍历才会产生

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    embedding1, embedding2 = sess.run([embedding1, get_embedding2])
    print("embedding1 \n", embedding1)
    print("embedding2 \n", embedding2)


embedding1 
 [[0.21 0.41 0.51 0.11]
 [0.22 0.42 0.52 0.12]
 [0.23 0.43 0.53 0.13]
 [0.24 0.44 0.54 0.14]]


embedding2 
 [[0.23 0.43 0.53 0.13]
 [0.24 0.44 0.54 0.14]
 [0.22 0.42 0.52 0.12]
 [0.21 0.41 0.51 0.11]]
# 可以看出emb2是通过emb1 根据下标[2,3,1,0]进行了行的调整

lookup 

- embedding_lookup函数的作用更像是一个搜索操作,即根据我们提供的索引,从对应的tensor中寻找对应位置的切片。

- 是gather函数的一种特殊形式

gather 提取行 

# gather, axis=0 (行)
# 当 params是二维的tensor,轴axis=0时,跟我们讲的embedding_lookup函数等价

embedding = tf.constant(
    [
        [0.21,0.41,0.51,0.11],
        [0.22,0.42,0.52,0.12],
        [0.23,0.43,0.53,0.13],
        [0.24,0.44,0.54,0.14]
    ],dtype=tf.float32)
 
index_a = tf.Variable([2,3,1,0])
gather_a = tf.gather(embedding, index_a)

# index_a
# <tf.Variable 'Variable:0' shape=(4,) dtype=int32_ref>

# gather_a
# <tf.Tensor 'GatherV2:0' shape=(4, 4) dtype=float32>

可以看到gather也实现了如上tf.one_hot的效果

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(gather_a))

# 在emb中,按照下标[2,3,1,0]顺序,构成了如下矩阵

[[0.23 0.43 0.53 0.13]
 [0.24 0.44 0.54 0.14]
 [0.22 0.42 0.52 0.12]
 [0.21 0.41 0.51 0.11]]

gather 提取列

# gather, axis=1 (列)
embedding = tf.constant(
    [
        [0.21,0.41,0.51,0.11],
        [0.22,0.42,0.52,0.12],
        [0.23,0.43,0.53,0.13],
        [0.24,0.44,0.54,0.14]
    ],dtype=tf.float32)
 
gather_a_axis1 = tf.gather(embedding, index_a, axis=1)


with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(gather_a_axis1))


[[0.51 0.11 0.41 0.21]
 [0.52 0.12 0.42 0.22]
 [0.53 0.13 0.43 0.23]
 [0.54 0.14 0.44 0.24]]

当emb是一维时

# 当params是一维的tensor
b = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
index_b = tf.Variable([2, 4, 6, 8])
gather_b = tf.gather(b, index_b)
 
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(gather_b))


[3 5 7 9]

# b下标2的值是3
# b下标4的值是5

gather_nd

tf.gather函数呢,我们只能通过一个维度的来获取切片,如果我们想要通过多个维度的联合索引来获取切片,可以通过gather_nd函数。

tf.reset_default_graph()
a = tf.Variable([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10], [11, 12, 13, 14, 15]])
a
<tf.Variable 'Variable:0' shape=(3, 5) dtype=int32_ref>

index_a = tf.Variable([0, 2])
index_a
<tf.Variable 'Variable_1:0' shape=(2,) dtype=int32_ref>


b = tf.get_variable(name='b',shape=[3,3,2],initializer=tf.random_normal_initializer)
b
<tf.Variable 'b:0' shape=(3, 3, 2) dtype=float32_ref>


index_b = tf.Variable([[0,1,1],[2,2,0]])
index_b
<tf.Variable 'Variable_2:0' shape=(2, 3) dtype=int32_ref>

结果

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print('-'*10)
    # 找a中第0个list的第2个值,是3
    print(sess.run(tf.gather_nd(a, index_a))) # 3 
    print('-'*10)



    print(sess.run(b)) # 随机初始化b矩阵,维度3,3,2
    print('-'*10)
    # 找b矩阵中,
    # [0,1,1], 第0层,第1层,的第一个值
    # [2,2,0],第2层,第2层,的第0个值
    print(sess.run(tf.gather_nd(b, index_b)))




----------
3 
----------
[[[ 0.5142302  -0.05901795]
  [-0.04706477  0.08232412]
  [-0.00842589 -1.1469455 ]]

 [[-0.4118051  -0.87490994]
  [-1.5529685   0.5411136 ]
  [ 0.49881363  2.527228  ]]

 [[ 0.19706753 -1.9549321 ]
  [ 0.551086    1.064308  ]
  [ 0.22157238 -2.3003275 ]]]
----------
[0.08232412 0.22157238]

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值