本文使用keras框架从侧面阐述
阅读本文的前提是对keras框架有一定的了解
keras中有个Embedding层,查看其源代码,核心代码为
def call(self, inputs):
if K.dtype(inputs) != 'int32':
inputs = K.cast(inputs, 'int32')
out = K.gather(self.embeddings, inputs)
return out
从上面的代码可以看出,我们需要去查看其后端函数gather
def gather(reference, indices):
"""Retrieves the elements of indices `indices` in the tensor `reference`.
# Arguments
reference: A tensor.
indices: An integer tensor of indices.
# Returns
A tensor of same type as `reference`.
"""
return tf.nn.embedding_lookup(reference, indices)
见到这个需要说明的函数了,写个简单的例子测试下
例子
from keras import backend as K
weights = K.constant([[1,2], [3,4], [5,6]], dtype='int32')
segment = [[1,2,1,0], [2,2,1,0]]
sess = K.get_session()
print( sess.run( K.gather( weights, segment ) ) )
sess.close()
weights张量可以看成词向量,segment可以看成一个batch的句子,显然
tf.nn.embedding_lookup
的功能就是在进行查表
操作。