embedding_lookup( )的用法
关于tensorflow中embedding_lookup( )的用法,在Udacity的word2vec会涉及到,本文将通俗的进行解释。
首先看一段网上的简单代码:
#!/usr/bin/env/python
# coding=utf-8
import tensorflow as tf
import numpy as np
input_ids = tf.placeholder(dtype=tf.int32, shape=[None])
embedding = tf.Variable(np.identity(5, dtype=np.int32))
input_embedding = tf.nn.embedding_lookup(embedding, input_ids)
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
print(embedding.eval())
print(sess.run(input_embedding, feed_dict={input_ids:[1, 2, 3, 0, 3, 2, 1]}))
代码中先使用palceholder定义了一个未知变量input_ids用于存储索引,和一个已知变量embedding,是一个5*5的对角矩阵。
运行结果为:
embedding = [[1 0 0 0 0]
[0 1 0 0 0]
[0 0 1 0 0]
[0 0 0 1 0]
[0 0 0 0 1]]
input_embedding = [[0 1 0 0 0]
[0 0 1 0 0]
[0 0 0 1 0]
[1 0 0 0 0]
[0 0 0 1 0]
[0 0 1 0 0]
[0 1 0 0 0]]
简单的讲就是根据input_ids中的id,寻找embedding中的对应元素。比如,input_ids=[1,3,5],则找出embedding中下标为1,3,5的向量组成一个矩阵返回。
如果将input_ids改写成下面的格式:
input_embedding = tf.nn.embedding_lookup(embedding, input_ids)
print(sess.run(input_embedding, feed_dict={input_ids:[[1, 2], [2, 1], [3, 3]]}))
输出结果就会变成如下的格式:
[[[0 1 0 0 0]
[0 0 1 0 0]]
[[0 0 1 0 0]
[0 1 0 0 0]]
[[0 0 0 1 0]
[0 0 0 1 0]]]
对比上下两个结果不难发现,相当于在np.array中直接采用下标数组获取数据。需要注意的细节是返回的tensor的dtype和传入的被查询的tensor的dtype保持一致;和ids的dtype无关。