引子
在学习各类Machine Learning方法时,免不了要与“分布”打交道。我们有时候需要计算某个分布的熵,有时候需要计算两个分布之间的交叉熵或KL散度。当然,这可以通过使用Numpy中的numpy.random.normal之类的函数来实现,但是我们更希望能够按照TensorFlow计算图的形式来实现,这样的话,可以更好地利用TensorFlow的一些优势(如一次性计算,共享计算结果等)。
简介
tf.distributions是TensorFlow提供的核心组件之一,用于实现一些常见的概率分布,并给出了一系列的辅助计算函数。首先,该组件中有Distribution基类、RegisterKL类、ReparameterizationType类。其中RegisterKL类是一个注册KL散度实现的装饰器,也即可以为某个分布添加KL散度的计算功能。此外,该组件还实现了以下分布:
- Bernoulli Distribution;
- Beta Distribution;
- Categorical Distribution;
- Dirichlet Distribution;
- Dirichlet-Multinomial Distribution;
- Exponential Distribution;
- Gamma Distribution;
- Laplace Distribution;
- Multinomial Distribution;
- Normal Distribution;
- StudentT Distribution;
- Uniform Distribution.
下面我们以Normal Distribution为例来进行介绍。
tf.distributions.Normal
Normal类型定义在tensorflow/python/ops/distributions/normal.py文件中。其__init__函数定义如下:
__init__(
loc,
scale,
validate_args=False,
allow_nan_stats=True,
name='Normal'
)
其中loc为高斯分布的均值
μ
\mu
μ,scale为标准差
σ
\sigma
σ。
在Normal类中,有如下properties:
- allow_nan_stats
- batch_shape
- dtype
- event_shape
- loc
- name
- parameters
- reparameterization_type
- scale
- validate_args
关于这些性质的解释就不赘述了。下面列出Normal类中给出的一些方法(列出来只是为了能够一目了然): - batch_shape_tensor (name=‘batch_shape_tensor’)
- cdf (value, name=‘cdf’)
- copy (**override_parameters_kwargs)
- covariance (name=‘covariance’)
- cross_entropy (other, name=‘cross_entropy’)
- entropy (name=‘entropy’)
- event_shape_tensor (name=‘event_shape_tensor’)
- is_scalar_batch (name=‘is_scalar_batch’)
- is_scalar_event (name=‘is_scalar_event’)
- kl_divergence (other, name=‘kl_divergence’)
- log_cdf (value, name=‘log_cdf’)
- log_prob (value, name=‘log_prob’)
- log_survival_function (value, name=‘log_survival_function’)
- mean (name=‘mean’)
- mode (name=‘mode’)
- param_shapes (cls, sample_shape, name=‘DistributionParamShapes’)
- param_static_shapes (cls, sample_shape)
- prob (value, name=‘prob’)
- quantile (value, name=‘quantile’)
- sample (sample_shape=(), seed=None, name=‘sample’)
- stddev (name=‘stddev’)
- survival_function (value, name=‘survival_function’)
- variance (name=‘variance’)
其中有计算熵的entropy方法,计算交叉熵的cross_entropy方法,计算KL散度(相对熵)的kl_divergence方法,这些方法为我们提供了极大的便利。
尾声
本文对tf.distributions进行了极简的介绍,大家如果对此有兴趣的话可以直接在TensorFlow官网查看,具体见:tf.distributions。
大家周五快乐~