tensorflow中tf.nn.embedding_lookup()用法示例

当查找对象是二维张量的时候

代码如下:

import tensorflow as tf
import numpy as np

input_ids = tf.placeholder(tf.int32, shape=[None], name="input_ids")
embedding = tf.Variable(np.identity(5, dtype=np.int32))
input_embedding = tf.nn.embedding_lookup(embedding, input_ids)

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
print("embedding=\n", embedding.eval())
print("input_embedding=\n", sess.run(input_embedding, feed_dict={input_ids: [1, 2, 3, 0, 3, 2, 1]}))


输出:

embedding=
 [[1 0 0 0 0]
 [0 1 0 0 0]
 [0 0 1 0 0]
 [0 0 0 1 0]
 [0 0 0 0 1]]
input_embedding=
 [[0 1 0 0 0]
 [0 0 1 0 0]
 [0 0 0 1 0]
 [1 0 0 0 0]
 [0 0 0 1 0]
 [0 0 1 0 0]
 [0 1 0 0 0]]
[Finished in 3.8s]

当查找索引时二维的时候

代码如下:

import tensorflow as tf
import numpy as np

input_ids = tf.placeholder(dtype=tf.int32, shape=[3, 2])
embedding = tf.Variable(np.identity(5, dtype=np.int32))
input_embedding = tf.nn.embedding_lookup(embedding, input_ids)
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

print("embedding=\n", embedding.eval())
print("input_embedding=\n", sess.run(input_embedding, feed_dict={input_ids: [[1, 2], [2, 1], [3, 3]]}))


输出如下:

embedding=
 [[1 0 0 0 0]
 [0 1 0 0 0]
 [0 0 1 0 0]
 [0 0 0 1 0]
 [0 0 0 0 1]]
input_embedding=
 [[[0 1 0 0 0]
  [0 0 1 0 0]]

 [[0 0 1 0 0]
  [0 1 0 0 0]]

 [[0 0 0 1 0]
  [0 0 0 1 0]]]
[Finished in 4.0s]

当查找对象是三维的时候

自己简单试验了一下,观看代码结果方便理解:

import tensorflow as tf
import numpy as np

input_ids = tf.placeholder(tf.int32, shape=[None], name="input_ids")
a=[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24]
aa=tf.reshape(a,[4,2,3])
embedding = tf.Variable(aa)
input_embedding = tf.nn.embedding_lookup(embedding, input_ids)


sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
print("embedding=\n", embedding.eval())
print("input_embedding=\n", sess.run(input_embedding, feed_dict={input_ids: [1, 2]}))

输出结果如下:

embedding=
 [[[ 1  2  3]
  [ 4  5  6]]

 [[ 7  8  9]
  [10 11 12]]

 [[13 14 15]
  [16 17 18]]

 [[19 20 21]
  [22 23 24]]]
input_embedding=
 [[[ 7  8  9]
  [10 11 12]]
 
 [[13 14 15]
  [16 17 18]]]

当查找索引是二维的时候

代码如下:

import tensorflow as tf
import numpy as np
input_ids = tf.placeholder(tf.int32, shape=[23], name="input_ids")
a=[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24]
aa=tf.reshape(a,[4,2,3])
embedding = tf.Variable(aa)
input_embedding = tf.nn.embedding_lookup(embedding, input_ids)

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
print("input_embedding=\n", sess.run(input_embedding, feed_dict={input_ids: [[1,2,3], [3,2,1]]}))

输出结果如下,省略了原始数据:

在这里插入代码片

在这里插入图片描述

AMHEN算法运行example结果图

在这里插入图片描述
epoch 9: 23%|| 1606/7066 [00:10<00:35, 155.83it/s]
加粗的字体字体是在运行过程中不断改变的。tqdm的动态可视化功能。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值