对于tf.nn.embedding_lookup()的使用

def embedding_lookup(
    params,
    ids,
    partition_strategy="mod",
    name=None,
    validate_indices=True,  
    max_norm=None):

主要的两个参数的理解:

params:一个tensor形式,一般为embedding后的shape,在进行分类中是[vocab_size, embedding_size]形式的值

ids:这个表示的是int数字形式的tensor,我们会根据这个ids进行对params里的值进行查询操作

所以总结来说就是:tf.nn.embedding_lookup()就是根据input_ids中的id,寻找embeddings中的第id行。比如input_ids=[2,8,10],则找出embeddings中第2,8,10行,组成一个tensor返回。

下面给出一段代码示例:

import tensorflow as tf
import numpy as np

a = np.random.rand(5,3)
b = np.array([[1,2],[3,1]])
print(a)
print("*" * 10 + "上面是a的输出结果" + "*" * 10)
print(b)
print("*" * 10 + "上面是b的输出结果" + "*" * 10)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
c = tf.nn.embedding_lookup(a ,b)
print(sess.run(c))
print("*" * 10 + "上面是c的输出结果" + "*" * 10)

下面是输出的结果,从结果中我们可以很明显看到通过b的给出的ids值查找到了a中对应索引位置的值

[[0.84679692 0.06116903 0.12382439]
 [0.49653969 0.74192834 0.07095517]
 [0.6960441  0.10387642 0.23318349]
 [0.43168901 0.5049978  0.83005329]
 [0.7418392  0.54963284 0.03186683]]
**********上面是a的输出结果**********
[[1 2]
 [3 1]]
**********上面是b的输出结果**********
[[[0.49653969 0.74192834 0.07095517]
  [0.6960441  0.10387642 0.23318349]]

 [[0.43168901 0.5049978  0.83005329]
  [0.49653969 0.74192834 0.07095517]]]
**********上面是c的输出结果**********

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

liu_sir_

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值