import tensorflow as tf
import numpy as np
# TF 1.x requires eager mode to be switched on explicitly, before any graph
# ops are created; TF 2.x runs eagerly by default and no longer provides
# `tf.enable_eager_execution`, so calling it unconditionally raises
# AttributeError there.
if hasattr(tf, "enable_eager_execution"):
    tf.enable_eager_execution()

# Embedding table: 3 rows (row indices 0..2), each a 3-element vector.
emb = np.array([
    [1, 1, 1],
    [2, 2, 2],
    [3, 3, 3],
])

# Ids to look up, shape (2, 2). Each id should be a valid row index of `emb`
# (0 <= id < 3).
ids = np.array([
    [1, 1],
    [2, 2],
])

# With a single params tensor, embedding_lookup gathers emb[i] for every i in
# `ids`, so the result has shape ids.shape + emb.shape[1:] == (2, 2, 3).
lookup = tf.nn.embedding_lookup(emb, ids=ids)
print(lookup)
# i.e. for `ids` of shape [2, 2], each id is replaced by its 3-element row of
# `emb`, so the final output has shape [2, 2, 3].
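# Expected printed output (dtype follows NumPy's default integer type:
# int32 on Windows, int64 on most Linux/macOS builds):
# tf.Tensor(
# [[[2 2 2]
#   [2 2 2]]
#
#  [[3 3 3]
#   [3 3 3]]], shape=(2, 2, 3), dtype=int32)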
# Article title: tf.nn.embedding_lookup
# (Latest recommended article published 2022-01-17 16:12:01)