tf.nn.fixed_unigram_candidate_sampler解释

https://www.tensorflow.org/api_docs/python/tf/random/fixed_unigram_candidate_sampler

 

上面链接是官网解释,看了一会儿感觉没看懂 跑了几个列子有点懂了。

本文结合https://www.w3cschool.cn/tensorflow_python/tf_nn_fixed_unigram_candidate_sampler.html    

给出更详细的解释如下:

tf.random.fixed_unigram_candidate_sampler(
    true_classes,
    num_true,
    num_sampled,
    unique,
    range_max,
    vocab_file='',
    distortion=1.0,
    num_reserved_ids=0,
    num_shards=1,
    shard=0,
    unigrams=(),
    seed=None,
    name=None
)

使用提供的(固定)基本分布对一组类进行采样.

该操作从整数范围[0,range_max]中随机采样num_sampled个类,所有的类的类别是[0, range_max), 每个类被采样的概率大小由参数unigrams指定,这个参数的值可以是概率的array,也可以是int的vector(表示出现次数,次数大表示被采样的概率大)

sampling_candidates的元素是在没有替换 (如果unique = True) 或替换 (如果unique = False) 的基础分布中绘制的. 

基本分布从文件中读取或作为内存中数组传入.还可以通过对权重应用distortion power(失真功率)来扭曲分布.

此外,此操作返回张量true_expected_count和sampled_expected_count,表示每个目标类(true_classes)和采样类(sampled_candidates)预期在平均张量的采样类中出现的次数.如果unique=True,则这些是拒绝后的概率,我们大致计算它们.

参数:

  • true_classes:一个int64类型的Tensor,具有shape [batch_size, num_true].目标类.
  • num_true:int,每个训练示例的目标类数.
  • num_sampled:int,随机抽样的类数.
  • unique:bool,确定批处理中的所有采样类是否都是唯一的.
  • range_max:int,可能的类数.
  • vocab_file:此文件中的每个有效行(应具有类似CSV的格式)对应于有效的单词ID.ID从num_reserved_ids开始按顺序排列.每行中的最后一个条目应该是对应于计数或相对概率的值.vocab_file和unigrams中的一个需要传递给此操作.
  • distortion:distortion(失真)用于扭曲unigram概率分布.在添加到内部unigram分布之前,首先将每个权重提升到失真的幂.结果,distortion = 1.0给出常规的unigram采样(由vocab文件定义),并且distortion = 0.0给出均匀分布.
  • num_reserved_ids:可选的,用户可以在范围[0, num_reserved_ids)内添加一些保留ID.一个用例是使用特殊的未知单词令牌作为ID 0.这些ID的抽样概率为0.
  • num_shards:采样器可用于从原始范围的子集中进行采样,以便通过并行性加速整个计算.此参数(与shard一起)表示在整体计算中使用的分区数.
  • shard:采样器可用于从原始范围的子集中进行采样,以便通过并行性加速整个计算.此参数(与num_shards一起)表示使用分区时操作的特定分区号.
  • unigrams:unigram计数或概率的列表,按顺序每个ID一个.应该将vocab_file和unigrams中的一个传递给此操作.可以是int可以是float
  • seed:int,特定于操作的种子.默认值为0.
  • name:操作的名称(可选).

返回:

  • sampled_candidates:int64类型和shape [num_sampled]的张量,抽样类.
  • true_expected_count:float类型的张量,shape与true_classes相同.每个true_classes的采样分布下的预期计数.
  • sampled_expected_count:float类型的张量.shape与sampled_candidates相同.每个sampled_candidates的采样分布下的预期计数.

测试例子


import tensorflow as tf


def test1():
    vec = tf.constant([[1, 2, 3, 4, 6]], dtype=tf.int64)
    # vec = tf.reshape(vec, [-1, 1])
    ids, _, _ = tf.nn.fixed_unigram_candidate_sampler(
        true_classes=vec,
        num_true=5,
        num_sampled=2,
        unique=False,
        range_max=5,
        vocab_file='',
        distortion=1.0,
        num_reserved_ids=0,
        num_shards=1,
        shard=0,
        unigrams=(0.1, 0.2, 0.3, 0.1, 0.3),
        # unigrams=(1, 2, 3, 1, 3),
    )
    # vs = ids(vec)


    with tf.Session() as sess:
        print sess.run(ids)


if __name__ == '__main__':
    test1()

输出

[4 0]

 

 

  • 2
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值