tf.nn.embedding_lookup

tf.nn.embedding_lookup(params, ids)
params: 一个张量或者数组
ids: 一个整型列表或一个二维矩阵,当输入为二维矩阵的时候,在CNN的时候会用,批量输入的时候ids为二维矩阵

该函数的目的是从params矩阵中返回行索引=ids中的元素的行向量组成矩阵
ids输入为一维列表的时候

import tensorflow as tf

table = tf.Variable(tf.random_normal([10, 5]))
b = tf.nn.embedding_lookup(table, [1, 4, 6, 7])
with tf.Session()as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(table))
    print(sess.run(b))

输出

[[ 0.2844954   1.0876138   0.2640958  -1.3939503   1.9493129 ]
 [-0.34022513 -0.22206968  0.19959041 -0.43038854  0.7214721 ]
 [ 1.2583389  -0.41636813  0.5526711  -0.04547537 -2.220672  ]
 [ 0.6416701  -0.04626859 -1.2670921  -1.0875092  -1.1969252 ]
 [-0.9369289   0.01590852 -1.0708148  -1.0230598   0.6950529 ]
 [-1.109506    0.43983954  1.1148814   0.48612115 -0.22546312]
 [ 0.7978611  -0.32981223  0.9465104   0.11148026 -0.8291709 ]
 [ 1.7482463  -0.84183437 -0.5938833   1.2219574   1.6940571 ]
 [ 0.3316857  -0.0637491   1.3450751   1.5049508  -0.66448265]
 [-0.56729424 -0.5770627   1.1358143   0.52266353 -2.49519   ]]
[[-0.34022513 -0.22206968  0.19959041 -0.43038854  0.7214721 ]
 [-0.9369289   0.01590852 -1.0708148  -1.0230598   0.6950529 ]
 [ 0.7978611  -0.32981223  0.9465104   0.11148026 -0.8291709 ]
 [ 1.7482463  -0.84183437 -0.5938833   1.2219574   1.6940571 ]]

ids为二维矩阵

import tensorflow as tf

table = tf.Variable(tf.random_normal([10, 5]))
b = tf.nn.embedding_lookup(table, [[1, 4, 6, 7],[1,4,2,5]])
with tf.Session()as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(table))
    print(sess.run(b))
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值