【三】TF接口-embedding_lookup

embedding_lookup是什么?

tf.nn.embedding_lookup函数的用法主要是选取一个张量里面索引对应的元素。embedding就是将输入文本表达成向量形式。向量化后需要用索引来查询对应向量,embedding_lookup就是帮助开发者来完成索引向量查询的。

为什么用embedding_lookup?

tf框架里,它会为输入的张量自动建立one-hot索引,但建立好的索引该如何与之后embedding向量对应起来并查询呢?这就需要通过索引-向量mapping表中去拿,此时,embedding_lookup就会帮助你完成这个操作。其实embedding_lookup本质是做了一次常规的线性变换,Z = WX + b。相当于通过one-hot的Weight矩阵,帮助使用者取出了矩阵中对应的那一行。相当于变相进行了一次矩阵相乘运算。看起来像查表一样。

什么时候使用embedding_lookup?

通常在训练之初,开始进行embedding的时候。看一下embedding_lookup的定义。

tf.nn.embedding_lookup(params, ids, partition_strategy='mod', name=None, validate_indices=True, max_norm=None)

TensorFlow官方文档定义见这里
其中params输入整型矩阵,用于给出索引embedding的idx。注意输出结果的shape其实相当于[params.shape(), embedding_shape()]

怎么使用embedding_lookup?

直接使用代码更容易说明:

import tensorflow as tf
import numpy as np

#a = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 2.2, 2.3], [3.1, 3.2, 3.3], [4.1, 4.2, 4.3]]
a = tf.Variable(np.identity(6, dtype=np.int32)) #np.asarray(a)
idx1 = tf.Variable([0, 2, 3, 1], tf.int32)
idx2 = tf.Variable([[0, 2, 3], [0, 2, 2]], tf.int32)
out1 = tf.nn.embedding_lookup(a, idx1)
out2 = tf.nn.embedding_lookup(a, idx2)
init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    print(sess.run(out1))
    print(out1)
    print('=' * 30)
    print(sess.run(out2))
    print(out2)

输出如下:

[[1 0 0 0 0 0]
 [0 0 1 0 0 0]
 [0 0 0 1 0 0]
 [0 1 0 0 0 0]]
Tensor("embedding_lookup/Identity:0", shape=(4, 6), dtype=int32)
==============================
[[[1 0 0 0 0 0]
  [0 0 1 0 0 0]
  [0 0 0 1 0 0]]

 [[1 0 0 0 0 0]
  [0 0 1 0 0 0]
  [0 0 1 0 0 0]]]
Tensor("embedding_lookup_1/Identity:0", shape=(2, 3, 6), dtype=int32)

之后,我们的out1和out2其实就可以作为训练的输入向量进行训练了。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值