tf.nn.embedding_lookup()函数

总是搞不清楚经过这个函数以后,张量的尺寸变化,现在进行一下总结。

tf.nn.embedding_lookup(
               params,
               ids,
               partition_strategy='mod',
               name=None,
               validate_indices=True,
               max_norm=None
)
参数说明:

params: 表示一个完整的张量。
ids: 一个类型为int32或int64的Tensor,包含要在params中查找的id。
partition_strategy: 指定分区策略的字符串,如果len(params)> 1,params的元素分割不能整分的话,则前(max_id + 1) % len(params)多分一个id。当前支持“div”和“mod”。 默认为“mod”。“mod”的划分形式为,比如9个ids划分为5个分区,则:[[0,5],[1,6],[2,7],[3,8],[4]],“div”的划分形式为比如9个ids分为5个分区,则:[[0,1],[2,3],[4,5],[6,7],[8]]
name: 操作名称(可选)
validate_indices: 是否验证收集索引
max_norm: 如果不是None,嵌入值将被l2归一化为max_norm的值。

代码展示
import tensorflow as tf

a = tf.Variable(tf.random_normal([5,4]))
b = [1,3]
c = tf.nn.embedding_lookup(a,b)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
print(sess.run(a))
print()
print(sess.run(b))

结果为
[[-0.63047475 0.26620588 0.60293007 -1.4683559 ]
[ 0.89961815 -0.6454514 1.671878 0.84864974]
[-1.7189969 0.5778266 0.0777434 0.8210027 ]
[-0.4449107 -1.0275484 1.1619903 -1.401834 ]
[ 0.9054293 0.6113461 0.4631226 -0.01497989]]

[[ 0.89961815 -0.6454514 1.671878 0.84864974]
[-0.4449107 -1.0275484 1.1619903 -1.401834 ]]

可以看到,通过用tf.nn.embedding_lookup()这个函数,根据b这个列表作为id指示,从a中找出相应id对应的数据。比如b=[1,3],则代表着从a中找出第1行和第3行的数据。

若b也为一个tensor,则其代码为:

import tensorflow as tf

a = tf.Variable(tf.random_normal([5,4]))
b = tf.Variable([[1,4,2,3,0],[3,1,0,1,4],[1,2,4,0,2]],tf.int32)
c = tf.nn.embedding_lookup(a,b)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
print(sess.run(a))
print()
print(sess.run(c))

结果为:
[[-0.11145564 0.06328291 -0.79501045 0.6656721 ]
[ 0.2718281 -2.1766105 -1.2639998 -0.46333247]
[-0.8792849 0.8736434 0.14076944 0.01519587]
[-0.18329176 -0.41718873 1.7577217 -1.4243344 ]
[-1.2562249 -0.97292686 1.0375485 1.3561603 ]]

[[[ 0.2718281 -2.1766105 -1.2639998 -0.46333247]
[-1.2562249 -0.97292686 1.0375485 1.3561603 ]
[-0.8792849 0.8736434 0.14076944 0.01519587]
[-0.18329176 -0.41718873 1.7577217 -1.4243344 ]
[-0.11145564 0.06328291 -0.79501045 0.6656721 ]]

[[-0.18329176 -0.41718873 1.7577217 -1.4243344 ]
[ 0.2718281 -2.1766105 -1.2639998 -0.46333247]
[-0.11145564 0.06328291 -0.79501045 0.6656721 ]
[ 0.2718281 -2.1766105 -1.2639998 -0.46333247]
[-1.2562249 -0.97292686 1.0375485 1.3561603 ]]

[[ 0.2718281 -2.1766105 -1.2639998 -0.46333247]
[-0.8792849 0.8736434 0.14076944 0.01519587]
[-1.2562249 -0.97292686 1.0375485 1.3561603 ]
[-0.11145564 0.06328291 -0.79501045 0.6656721 ]
[-0.8792849 0.8736434 0.14076944 0.01519587]]]

如果b也是一个tensor,则代表从a中按b每一行所指示的索引输出,比如上述代码中,b[0]为[1,4,2,3,0],则代表从a中输出第1,4,2,3,0行数据作为c的第一组数据,上述代码中,b的大小为3x5,说明要从a中抽出三组数据,每组数据的行数为5,a的大小为5x4,代表b从a中抽出的每一组数据都为4列,所以c的大小就变成了3x5x4.

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值