tensorflow2.x实现两个多元高斯分布之间的KL散度,很重要

0.背景

  • 现在假设你要用tensorflow计算两个多元高斯分布之间的KL散度,用闭式解,该如何用tensorflow2.x实现。
    在这里插片描述
  • 看到这个公式,相比大家都是头疼的,尤其在训练时候,还要考虑Batch的维度。今天就用tensorflow实现一下。

1. tensorflow矩阵操作

1.1 多维矩阵的乘法

  • 一般我们都考虑二维矩阵的乘法,只需要注意两个矩阵的维度即可。但是,有的时候,我们还需要考虑例如Batch_size怎么搞,这是这一小结要解决的问题。

1.1.1 tf.matmul函数

点我

  • a_is_sparse或者b_is_sparse只有在rank=2时才有用,不然会报错
  • 我们测试二维和三维的时候,看看这个操作是否只考虑用后两个维度进行矩阵乘积,而其它保持不变,并且维度大小相等
  • 要注意的是:两个矩阵的数据类型要一致,不然会报错
# 2-D tensor `a`
a = tf.constant([1, 2, 3, 4, 5, 6], shape=[2, 3]) => [[1. 2. 3.]
                                                      [4. 5. 6.]]
# 2-D tensor `b`
b = tf.constant([7, 8, 9, 10, 11, 12], shape=[3, 2]) => [[7. 8.]
                                                         [9. 10.]
                                                         [11. 12.]]
c = tf.matmul(a, b) shape=(2,2) => [[58 64]
                                   [139 154]]

# 3-D tensor `b`
b = tf.constant(np.arange(13, 25, dtype=np.int32),
                shape=[2, 3, 2])                   => [[[13. 14.]
                                                        [15. 16.]
                                                        [17. 18.]],
                                                       [[19. 20.]
                                                        [21. 22.]
                                                        [23. 24.]]]
c = tf.matmul(a, b) shape(2,2,2)=> [[[ 94 100]
                                    [229 244]],
                                    [[508 532]
                                    [697 730]]]
  • 结论:在进行矩阵乘积,我们说的时点乘,不是逐元素乘积(use “⭐”),使用matmul函数,该函数只考虑最后两个维度,其余维度不变同时要求两个矩阵的这些维度大小相等,如果我们想探究一下源码,OK:

  • 答案是,当矩阵rank>2时,会自动调用batch_mat_mul函数

    if (not a_is_sparse and
        not b_is_sparse) and ((a_shape is None or len(a_shape) > 2) and
                              (b_shape is None or len(b_shape) > 2)):
      # BatchMatmul does not support transpose, so we conjugate the matrix and
      # use adjoint instead. Conj() is a noop for real matrices.
      if transpose_a:
        a = conj(a)
        adjoint_a = True
      if transpose_b:
        b = conj(b)
        adjoint_b = True
      return gen_math_ops.batch_mat_mul(
          a, b, adj_x=adjoint_a, adj_y=adjoint_b, name=name)
 
 
# 上述代码中batch_mat_mul的定义
def batch_mat_mul(x, y, adj_x=False, adj_y=False, name=None):
  r"""Multiplies slices of two tensors in batches.
  Multiplies all slices of `Tensor` `x` and `y` (each slice can be
  viewed as an element of a batch), and arranges the individual results
  in a single output tensor of the same batch size. Each of the
  individual slices can optionally be adjointed (to adjoint a matrix
  means to transpose and conjugate it) before multiplication by setting
  the `adj_x` or `adj_y` flag to `True`, which are by default `False`.
  
  The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]`
  and `[..., r_y, c_y]`.
  
  The output tensor is 2-D or higher with shape `[..., r_o, c_o]`, where:
  
      r_o = c_x if adj_x else r_x
      c_o = r_y if adj_y else c_y
  
  It is computed as:
  
      output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :])
  Args:
    x: A `Tensor`. Must be one of the following types: `bfloat16`, `half`, `float32`, `float64`, `int32`, `int64`, `complex64`, `complex128`.
      2-D or higher with shape `[..., r_x, c_x]`.
    y: A `Tensor`. Must have the same type as `x`.
      2-D or higher with shape `[..., r_y, c_y]`.
    adj_x: An optional `bool`. Defaults to `False`.
      If `True`, adjoint the slices of `x`. Defaults to `False`.
    adj_y: An optional `bool`. Defaults to `False`.
      If `True`, adjoint the slices of `y`. Defaults to `False`.
    name: A name for the operation (optional).
  Returns:
    A `Tensor`. Has the same type as `x`.
  """

1.1.2 使用 @ 重载函数

  • @ 函数是一个重载运算符号,以下是它的解释

  # Since python >= 3.5 the @ operator is supported (see PEP 465).
  # In TensorFlow, it simply calls the `tf.matmul()` function, so the
  # following lines are equivalent:
  d = a @ b @ [[10.], [11.]]
  d = tf.matmul(tf.matmul(a, b), [[10.], [11.]])

1.2 多维矩阵的转置

  • 设想以下,我们现在得到了一个均值向量μ,shape=(Batch,μ_dim),对于Batch中的每一个μ,我们都希望得到它的转置,即得到shape=(Batch,1,μ_dim)或者(Batch,μ_dim,1),我们要做的有两步:

1.2.1 用tf.expand_dims扩展维度

tf.expand_dims(a,axis=),该函数将指定的tensor a在指定的维度上增加一个维度,置为1,举个例子:

a = tf.ones(shape=(10,5))
a = tf.expand_dim(a,axis=-1)

print(a.shape)
>>> (10,5,1)

所以我们可以使用该函数将
(B,μ_dim) ------>(B,μ_dim,1),方便我们后续的矩阵操作

1.2.2 tf.squeeze(a,axis=)

  • 该函数是和tf.expand_dims相反的函数,去掉维度为1的维度
a = tf.ones(shape=(1,1,2,2))
a = tf.squeeze(a,# axis= 指定维度)
a.shape

# result
(2,2)

1.2.3 矩阵转置

方法一:使用tensorflow1.x版本的tf.compat.v1.matrix_transpose函数

  • 它会转置张量的最后两个维度,很方便,很适合在rank>2的情况下使用
  • tf.compat.v1是在使用tensorflow2.x时,用这个函数兼容tensorflow1.x的函数库,英文是compatiable version1
matrix_transpose(
    a,
    name='matrix_transpose'
)
# Matrix with no batch dimension.
# 'x' is [[1 2 3]
#         [4 5 6]]
tf.matrix_transpose(x) ==> [[1 4]
                                 [2 5]
                                 [3 6]]

Matrix with two batch dimensions.
x.shape is [1, 2, 3, 4]
tf.compat.v1.matrix_transpose(x) is shape [1, 2, 4, 3]

方法二:使用tensorflow2.x版本的tf.transpose()函数

建议参考下面这个博客:
https://blog.csdn.net/qq_40994943/article/details/85270159

1.3 求矩阵的行列式

  • 使用tf.compat.v1.matrix_determinant()函数
matrix_determinant(
    input,
    name=None
)

我们知道只有满秩的矩阵才有行列式,所以input必须满足形状为:[…, M, M],这个函数可以帮助我们计算一个batch的行列式。

a = tf.ones(shape=(5,4,4))
det_a = tf.compat.v1.matrix_determinant(a)
print(det_a)

# result
<tf.Tensor: shape=(5,), dtype=float32, numpy=array([0., 0., 0., 0., 0.], dtype=float32)>

1.4 求矩阵的逆

1.4.1 tf.matrix_inverse()

在这里插入图片描述

  • 在深度学习领域,我们一般都考虑特征变量之间的协方差为0,所以我们只需要得到矩阵的对角线元素,然后用对角线元素去构建矩阵即可。
  • 另外需要注意的是,在进行矩阵的逆操作时,要保证矩阵可逆

1.4.2 tf.compat.v1.matrix_diag()

2. 联合起来,就能更强

将上面的这些操作联合起来,我们就能计算出两个高斯分布之间的KL散度

def compute_kl(u1,sigma1,u2,sigma2,dim):
    """
    计算两个多元高斯分布之间KL散度KL(N1||N2);
    
    所有的shape均为(B1,B2,...,dim),表示协方差为0的多元高斯分布
    这里我们假设加上Batch_size,即形状为(B,dim)
    
    dim:特征的维度
    """
    sigma1_matrix = tf.compat.v1.matrix_diag(sigma1) # (B,dim,dim)
    sigma1_matrix_det = tf.compat.v1.matrix_determinant(sigma1_matrix) # (B,)
    
    sigma2_matrix = tf.compat.v1.matrix_diag(sigma2) # (B,dim,dim)
    sigma2_matrix_det = tf.compat.v1.matrix_determinant(sigma2_matrix) # (B,)
    sigma2_matrix_inv = tf.compat.v1.matrix_diag(1./sigma2) # (B,dim,dim)
    
    delta_u = tf.expand_dims((u2-u1),axis=-1) # (B,dim,1)
    delta_u_transpose = tf.compat.v1.matrix_transpose(delta_u) # (B,1,dim)
    
    term1 = tf.reduce_sum((1./sigma2)*sigma1,axis=-1) # (B,) represent trace term
    term2 = delta_u_transpose @ sigma2_matrix_inv @ delta_u  # (B,)
    term3 = -dim
    term4 = tf.math.log(sigma2_matrix_det) - tf.math.log(sigma1_matrix_det)
    
    KL = 0.5 * (term1 + term2 + term3 + term4)
    
    # if you want to compute the mean of a batch,then,
    KL_mean = tf.reduce_mean(KL)
    
    return KL_mean

# 测试
dim = 5
u1 = tf.zeros(shape=(10,5))
sigma1 = tf.ones(shape=(10,5))
u2 = tf.zeros(shape=(10,5))
sigma2 = tf.ones(shape=(10,5))

dim = 5
u1 , sigma1 = tf.zeros(shape=(10,5)),tf.ones(shape=(10,5)) # N(0,I)
u2 ,sigma2 = tf.zeros(shape=(10,5)),tf.ones(shape=(10,5))  # N(0,I)
u3 ,sigma3 = tf.zeros(shape=(10,5)),4*tf.ones(shape=(10,5)) # N(0,4I)

KL1 = compute_kl(u1,sigma1,u2,sigma2,dim)
KL2 = compute_kl(u1,sigma1,u3,sigma3,dim)

print(KL1,"\n",KL2,sep='')

# result
# tf.Tensor(0.0, shape=(), dtype=float32)
# tf.Tensor(1.5907359, shape=(), dtype=float32)
  • 0
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

InceptionZ

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值