tensorflow.contrib.distributions.RelaxedOneHotCategorical
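RelaxedOneHotCategorical is a distribution over random probability vectors, that is, vectors of positive real values that sum to one. It continuously approximates a OneHotCategorical. The degree of approximation is controlled by a temperature: as the temperature goes to 0, the RelaxedOneHotCategorical becomes discrete, with a distribution described by the logits or probs parameters; as the temperature goes to infinity, it becomes the constant distribution that is identical to the constant vector (1/event_size, …, 1/event_size).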
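The RelaxedOneHotCategorical distribution was introduced concurrently as the Gumbel-Softmax (Jang et al., 2016, https://arxiv.org/abs/1611.01144) and the Concrete (Maddison et al., 2016) distribution, for use as a reparameterized continuous approximation to the Categorical one-hot distribution. If you use this distribution, please cite both papers: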
[1] E. Jang, S. Gu, and B. Poole. Categorical reparameterization with gumbel-softmax. arXiv preprint arXiv:1611.01144, 2016.
[2] C. J. Maddison, A. Mnih, and Y. W. Teh. The concrete distribution: A continuous relaxation of discrete random variables. arXiv preprint arXiv:1611.00712, 2016.
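To make the role of the temperature concrete, here is a minimal pure-Python sketch of the Gumbel-Softmax sampling scheme described in [1]: add standard Gumbel noise to the logits, divide by the temperature, and apply a softmax. The helper names `softmax` and `sample_relaxed_one_hot` are made up for this illustration and are not part of TensorFlow; the library's own implementation may differ in detail.

```python
import math
import random


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def sample_relaxed_one_hot(logits, temperature, rng):
    """Draw one Gumbel-Softmax sample (hypothetical helper, not a TensorFlow API)."""
    perturbed = []
    for logit in logits:
        u = rng.random()
        while u == 0.0:  # avoid log(0); random() can return exactly 0.0
            u = rng.random()
        gumbel = -math.log(-math.log(u))  # standard Gumbel(0, 1) noise
        perturbed.append((logit + gumbel) / temperature)
    return softmax(perturbed)


rng = random.Random(0)
logits = [math.log(0.7), math.log(0.2), math.log(0.1)]

# Low temperature: samples tend toward one-hot vectors.
# Very high temperature: samples approach (1/3, 1/3, 1/3).
for temperature in (0.1, 1.0, 1e6):
    sample = sample_relaxed_one_hot(logits, temperature, rng)
    print(temperature, [round(p, 3) for p in sample])
```

Every sample is a probability vector. At a very high temperature the noise and the logits are both scaled toward zero, so the softmax output approaches the uniform vector. The index of the largest entry does not depend on the temperature; by the Gumbel-max argument it follows the categorical distribution given by the logits, which is why the low-temperature limit recovers the discrete distribution.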
Args
temperature: A 0-D Tensor representing the temperature of a set of RelaxedOneHotCategorical distributions. The temperature should be positive.
logits: An N-D Tensor, N >= 1, representing the log probabilities of a set of RelaxedOneHotCategorical distributions. The first N - 1 dimensions index into a batch of independent distributions and the last dimension represents a vector of logits for each class. Only one of logits or probs should be passed in.
probs: An N-D Tensor, N >= 1, representing the probabilities of a set of RelaxedOneHotCategorical distributions. The first N - 1 dimensions index into a batch of independent distributions and the last dimension represents a vector of probabilities for each class. Only one of logits or probs should be passed in.
dtype: The type of the event samples (default: inferred from logits/probs).
validate_args: Unused in this distribution.
allow_nan_stats: Python bool, default True. If False, raise an exception if a statistic (e.g. mean/mode/etc.) is undefined for any batch member. If True, batch members with valid parameters leading to undefined statistics will return NaN for this statistic.
name: A name for this distribution (optional).
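The relationship between logits and probs, and the way the leading dimensions index a batch, can be sketched in plain Python with nested lists standing in for tensors. This assumes that probs and logits are related by an elementwise log, as the phrase "log probabilities" above suggests; `probs_to_logits` and `softmax` are hypothetical helpers for this illustration, not TensorFlow functions.

```python
import math


def probs_to_logits(probs):
    """Hypothetical helper: elementwise log turns probabilities into logits."""
    return [math.log(p) for p in probs]


def softmax(values):
    """Numerically stable softmax; maps logits back to probabilities."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


# A 2-D parameter (N = 2): the first dimension indexes a batch of two
# independent distributions, the last dimension holds three class probabilities.
batch_probs = [
    [0.7, 0.2, 0.1],
    [0.1, 0.1, 0.8],
]

batch_logits = [probs_to_logits(row) for row in batch_probs]
recovered = [softmax(row) for row in batch_logits]

batch_shape = (len(batch_probs),)   # the first N - 1 dimensions
event_size = len(batch_probs[0])    # the last dimension: number of classes
print(batch_shape, event_size)
```

Because either parameter determines the other under this assumption, supplying both would be redundant, which is consistent with the rule that only one of logits or probs should be passed in.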
References: https://tensorflow.google.cn/probability/api_docs/python/tfp/distributions/RelaxedOneHotCategorical
https://zhuanlan.zhihu.com/p/50065712