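While reading the GraphSage code recently, I came across the function tf.nn.fixed_unigram_candidate_sampler.
The hard part of understanding it is mainly two parameters, num_true and range_max, so let's go straight to the code.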
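import tensorflow as tf  # TensorFlow 1.x API (uses tf.Session)


def test1():
    # true_classes has shape [batch_size, num_true] = [1, 5].
    vec = tf.constant([[1, 2, 3, 4, 6]], dtype=tf.int64)
    # vec = tf.reshape(vec, [-1, 1])
    # If the reshape above is enabled, the shape becomes [5, 1], which no
    # longer matches num_true=5, and the call fails with:
    # tensorflow.python.framework.errors_impl.InvalidArgumentError: true_classes must have num_true columns, expected: 1 was: 5
    ids, _, _ = tf.nn.fixed_unigram_candidate_sampler(
        true_classes=vec,
        num_true=5,        # number of columns in true_classes
        num_sampled=2,     # number of ids to sample
        unique=False,
        range_max=5,       # number of possible classes; ids are drawn from [0, 5)
        vocab_file='',
        distortion=1.0,
        num_reserved_ids=0,
        num_shards=1,
        shard=0,
        unigrams=(0.1, 0.2, 0.3, 0.1, 0.3),  # one weight per class id
        # unigrams=(1, 2, 3, 1, 3),
    )
    with tf.Session() as sess:
        print(sess.run(ids))


if __name__ == '__main__':
    test1()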
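The shape requirement behind that error can be sketched without TensorFlow. The sampler treats true_classes as a [batch_size, num_true] matrix, so num_true has to equal the number of columns. The helper check_true_classes below is hypothetical and only for illustration; it is not part of TensorFlow, and its message is simply modeled on the error quoted above.

```python
def check_true_classes(true_classes, num_true):
    """Hypothetical helper that mimics the column check on true_classes."""
    for row in true_classes:
        if len(row) != num_true:
            raise ValueError(
                "true_classes must have num_true columns, "
                f"expected: {len(row)} was: {num_true}"
            )
    return len(true_classes)  # batch_size


row_form = [[1, 2, 3, 4, 6]]             # shape [1, 5]
column_form = [[1], [2], [3], [4], [6]]  # shape [5, 1], as after reshape(vec, [-1, 1])

print(check_true_classes(row_form, num_true=5))     # 1
print(check_true_classes(column_form, num_true=1))  # 5
```

Calling check_true_classes(column_form, num_true=5) raises a ValueError, which corresponds to combining the reshape with num_true=5.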
import tensorflow as tf  # TensorFlow 1.x API (uses tf.Session)


def test1():
    vec = tf.constant([[1, 2, 3, 4, 6]], dtype=tf.int64)
    # After the reshape, true_classes has shape [5, 1], so num_true must be 1.
    vec = tf.reshape(vec, [-1, 1])
    ids, _, _ = tf.nn.fixed_unigram_candidate_sampler(
        true_classes=vec,
        num_true=1,        # one true class per row
        num_sampled=2,
        unique=False,
        range_max=5,
        vocab_file='',
        distortion=1.0,
        num_reserved_ids=0,
        num_shards=1,
        shard=0,
        unigrams=(0.1, 0.2, 0.3, 0.1, 0.3),
        # unigrams=(1, 2, 3, 1, 3),
    )
    with tf.Session() as sess:
        print(sess.run(ids))


if __name__ == '__main__':
    test1()
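For intuition about range_max and unigrams, here is a pure-Python sketch of the sampling idea: ids are drawn from [0, range_max) with probability proportional to unigrams[i] ** distortion, which is why unigrams carries one weight per class id. This approximates the documented behavior with random.choices and is not TensorFlow's implementation. The names sample_candidates, normalize, and the seed parameter are hypothetical, and only the unique=False case is covered.

```python
import random


def sample_candidates(unigrams, range_max, num_sampled, distortion=1.0, seed=None):
    """Draw num_sampled ids from [0, range_max) with replacement (unique=False)."""
    if len(unigrams) != range_max:
        # Assumption for this sketch: one weight per class id.
        raise ValueError("expected one unigram weight per class id")
    # Each weight is raised to the distortion power before sampling.
    weights = [w ** distortion for w in unigrams]
    rng = random.Random(seed)
    return rng.choices(range(range_max), weights=weights, k=num_sampled)


def normalize(weights):
    """Turn counts or unnormalized weights into probabilities."""
    total = sum(weights)
    return [w / total for w in weights]


probs = (0.1, 0.2, 0.3, 0.1, 0.3)
counts = (1, 2, 3, 1, 3)

# Counts and probabilities describe the same distribution once normalized.
print(normalize(counts))
print(sample_candidates(probs, range_max=5, num_sampled=2, seed=0))
```

Since only the relative weights matter, passing counts or float probabilities should lead to the same sampling distribution.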
References:
How to sample in TensorFlow from a custom probability distribution?
https://www.thinbug.com/q/49713210
Erratum: unigrams can also be given as floats.
Explanation of tf.nn.fixed_unigram_candidate_sampler
https://blog.csdn.net/u011026968/article/details/88537939
TensorFlow function tutorial: tf.nn.fixed_unigram_candidate_sampler
https://www.w3cschool.cn/tensorflow_python/tf_nn_fixed_unigram_candidate_sampler.html