tf.multinomial()/tf.random.categorical()用法解析

本文深入解析了TensorFlow中tf.random.categorical()函数的使用方法,替代了已废弃的tf.multinomial(),并详细解释了如何根据未正规化的log概率进行样本抽取,通过实例演示了采样过程和结果。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

tf.multinomial()/tf.random.categorical()用法解析

首先说一下,tf.multinomial()在tensorflow2.0版本已经被移除,取而代之的就是tf.random.categorical()

网上的很多博客解释的都不清楚,官网......解释的也很模糊,于是想自我总结一下,顺便帮助对此也很困惑的人~

因为tf.multinomial()被tf.random.categorical()替代,所以下文以tf.random.categorical()为描述方式进行介绍

官网的解释

tf.random.categorical

从一个分类分布中抽取样本(tf.multinomial()是多项分布)

别名:

  • tf.compat.v1.random.categorical
  • tf.compat.v2.random.categorical
tf.random.categorical(
    logits,
    num_samples,
    dtype=None,
    seed=None,
    name=None
)

例子:

# samples has shape [1, 5], where each value is either 0 or 1 with equal
# probability.
samples = tf.random.categorical(tf.math.log([[10., 10.]]), 5)

参数:

  • logits: 形状为 [batch_size, num_classes]的张量. 每个切片 [i, :] 代表对于所有类的未正规化的log概率。
  • num_samples: 0维,从每一行切片中抽取的独立样本的数量。
  • dtype: 用于输出的整数类型,默认为int64。
  • seed: 一个Python整数,用于创建分布的随机种子。See tf.compat.v1.set_random_seedfor behavior.
  • name: 操作的可选名字

Returns:

形状为[batch_size, num_samples]的抽取样本.

个人理解

1. 这个函数的意思就是,你给了一个batch_size × num_classes的矩阵,这个矩阵是这样的:
每一行相当于log(p(x)),这里假设p(x)=[0.4,0.3,0.2,0.1],(p(x)的特性就是和为1),
然后再取log,那么log(p(x))就等于[-0.9162907 -1.20397282 -1.60943794 -2.30258512]
函数利用你给的分布概率,从其中的每一行中抽取num_samples次,最终形成的矩阵就是batch_szie × num_samples了。

2. 这里的抽样方法可以再详细解释一下,举个例子(请不要考虑真实性),给一行[1.0,2.0,2.0,2.0,6.0],采样4次,那么结果很大可能都是[4,4,4,4](不信可以试一下),因为下标为4的概率(6.0)远远高于其他的概率,当然也会出现比如[4,4,2,4]这样的情况,就是说其他的下标因为给定的概率就低,所以被采样到的概率也就低了。

3. 官网解释中logits,也就是你给的矩阵,每个切片 [i, :] 代表对于所有类的未正规化的log概率(即其和不为1),但必须是小数,就像官网的样例一样,就算是整数,后面也要加一个小数点,否则会报错。

4. 返回值是什么的问题,返回的其实不是抽取到的样本,而是抽取样本在每一行的下标。

为了能更加充分的理解,下面奉上一个小小的例子:

import tensorflow as tf;
for i in tf.range(10):
    samples = tf.random.categorical([[1.0,1.0,1.0,1.0,4.0],[1.0,1,1,1,1]], 6)
    tf.print(samples)

 输出结果

[[4 4 4 4 4 1]
 [3 1 3 0 4 3]]
[[4 0 4 4 4 1]
 [1 0 2 4 1 2]]
[[0 4 4 0 4 4]
 [3 0 0 1 1 4]]
[[4 4 4 4 4 0]
 [2 1 4 3 4 4]]
[[4 4 2 4 4 4]
 [1 3 1 0 4 0]]
[[4 4 4 4 4 4]
 [3 0 4 1 1 1]]
[[4 4 0 0 4 4]
 [3 3 0 3 2 2]]
[[1 4 4 4 4 4]
 [2 2 1 3 0 2]]
[[4 4 4 4 4 4]
 [2 4 4 3 2 2]]
[[4 4 4 4 3 4]
 [2 4 2 2 1 0]]

看到这估计你就能理解了,其中[[1.0,1.0,1.0,1.0,4.0],[1.0,1,1,1,1]]就是需要进行采样的矩阵,这里加小数点其实可以只加一个,只要让程序知道你用的是概率就行(当然实际都是通过tf.log()得到的不用手动输入),输出结果自然就是样本所在行的下标,多运行几次,就能更直观的感受到,设定的概率和采样结果之间的关系。(比如这里第一行的采样结果很多都是最后一个样本,第二行因为概率相同,采样结果就很均匀)

就这么多啦,如果文章有错误或者有疑问欢迎评论区交流呀(●’◡’●)ノ ~

评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值