0. 引言
前几天分几篇博文精细地讲述了《von Mises-Fisher 分布》, 以及相应的 PyTorch 实现《von Mises-Fisher Distribution (代码解析)》, 其中以 Uniform 分布为例简要介绍了 torch.distributions
包的用法. 本以为已经可以了, 但这两天看到论文 The Power Spherical distribution 的代码, 又被其实现分布的方式所吸引.
Power Spherical 分布与 von Mises Fisher 分布类似, 只不过将后者概率密度函数中的指数函数换成了多项式函数: f p ( x ; μ , κ ) ∝ e x p ( κ μ ⊺ x ) ⇓ f p ( x ; μ , κ ) ∝ ( 1 + μ ⊺ x ) κ \begin{aligned} f_p(\bm{x}; \bm{\mu}, \kappa) &\propto exp(\kappa \bm{\mu}^\intercal \bm{x}) \\ &\Downarrow\\ f_p(\bm{x}; \bm{\mu}, \kappa) &\propto (1+\bm{\mu}^\intercal \bm{x})^\kappa \\ \end{aligned} fp(x;μ,κ)fp(x;μ,κ)∝exp(κμ⊺x)⇓∝(1+μ⊺x)κ 采样框架基本一致, 且这么做可以使边缘 t t t 的线性变换 t + 1 2 ∼ B e t a ( p − 1 2 + κ , p − 1 2 ) \frac{t+1}{2} \sim Beta(\frac{p-1}{2}+\kappa, \frac{p-1}{2}) 2t+1∼Beta(2p−1+κ,2p−1), 从而避免了接受-拒绝采样过程.
当然, 按照之前的 VonMisesFisher
的写法, 这个 t
的采样大概是这样:
z = beta.sample(sample_shape)
t = 2 * z - 1
但现在我遇到了这种写法:
class MarginalTDistribution(tds.TransformedDistribution):
arg_constraints = {
'dim': constraints.positive_integer,
'scale': constraints.positive,
}
has_rsample = True
def __init__(self, dim, scale, validate_args=None):
self.dim = dim
self.scale = scale
super().__init__(
tds.Beta( # 用 Beta 分布转换, z 服从 Beta(α+κ,β)
(dim - 1) / 2 + scale, (dim - 1) / 2, validate_args=validate_args
),
transforms=tds.AffineTransform(loc=-1, scale=2), # t=2z-1 是想要的边缘分布随机数
)
然后就可以进行对 t t t 的采样了.
架构大概是这样的: 一个基本分布类 distributions.Beta
和一个转换 transforms.AffineTransform
, 输入到 TransformedDistribution
的子类 MarginalTDistribution
中, 通过对一个
B
e
t
a
Beta
Beta 的线性转换, 实现边缘分布
t
t
t.
我们可以看到其基本架构, 本文将详细解析其内部的具体细节, 包括:
1. Distribution
在之前的 <von Mises-Fisher Distribution (代码解析)> 中, 已经通过 Uniform
简单介绍了 Distribution
的用法. 它是实现各种分布的抽象基类. 本文将以解析源码的方式详细介绍.
1.1 参数验证 validate_args
打开源码, 首先映入眼帘的是关于参数验证的代码:
# true if Python was not started with an -O option. See also the assert statement.
_validate_args = __debug__
@staticmethod
def set_default_validate_args(value: bool) -> None:
"""
设置 validation 是否开启.
validation 通常是耗时的, 所以最好在模型 work 后关闭它.
"""
if value not in [True, False]:
raise ValueError
Distribution._validate_args = value
Distribution
有一个类属性叫 _validate_args
, 默认值是 __debug__
(见附录1), 可以通过类静态方法 set_default_validate_args(value: bool)
来修改此值.
构造方法 __init__(...)
中的验证逻辑:
def __init__(self, ..., validate_args: Optional[bool]=None):
...
if validate_args is not None:
self._validate_args = validate_args
也就是说, 你可以在创建 Distribution
实例的时候设置是否进行参数验证. 如果不设置, 则按照类的属性 Distribution._validate_args
来.
if self._validate_args: # validate_args=False 就不用设置 arg_constraints 了
try: # 尝试获取字典 arg_constraints
arg_constraints = self.arg_constraints
except NotImplementedError: # 如果没设置, 则设置为 {}, 抛出警告
arg_constraints = {}
warnings.warn(...)
如果需要验证参数, 那么首先要获取一个叫 arg_constraints
的参数验证字典, 它列出了需要验证哪些参数. 这个抽象类里面并没有给出, 需要用户继承该类时写在子类中. 以 Uniform
为例:
class Uniform(Distribution):
...
arg_constraints = {
"low": constraints.dependent(is_discrete=False, event_dim=0),
"high": constraints.dependent(is_discrete=False, event_dim=0),
}
...
至于 constraints.dependent
是啥, 后面会详细介绍. 值得注意的是, 如果你在创建实例时指定 validate_args=False
, 那么所有关于参数验证的事就都不用管了.
for param, constraint in arg_constraints.items():
if constraints.is_dependent(constraint):
continue # skip constraints that cannot be checked
if param not in self.__dict__ and isinstance(
getattr(type(self), param), lazy_property
):
continue # skip checking lazily-constructed args
value = getattr(self, param) # 从当前对象获取参数 value
valid = constraint.check(value) # 检查参数值
if not valid.all(): # 检查不通过
raise ValueError(...)
这一段就是验证过程了, 包括:
- skip constraints that cannot be checked, 由
constraints.is_dependent(constraint)
判断是否可验证; - skip checking lazily-constructed args, 即参数名不在
self.__dict__
中, 并属于lazy_property
的跳过; - 获得参数, 进行验证;
具体的验证细节将在后面介绍.
1.2 batch_shape
& event_shape
除了 validate_args
参数, __init__(...)
方法中的另外两个参数就是:
def __init__(
self,
batch_shape: torch.Size = torch.Size(),
event_shape: torch.Size = torch.Size(),
):
self._batch_shape = batch_shape
self._event_shape = event_shape
...
这两个参数是啥? 在这个抽象类中, 我们看不到太多信息, 甚至 Uniform
中也只有 batch_shape = self.low.size()
的信息, 大概意思同时进行着一批的均匀分布, 如 low = torch.tensor([0.0, 1.0])
时, batch_shape = torch.Size([2])
, 表示一个二元的均匀分布. 看 MultivariateNormal
, 里面信息量较大:
batch_shape = torch.broadcast_shapes(
covariance_matrix.shape[:-2], # [:-2]是去掉了协方差矩阵的维度, 剩下的可能是 batch 的维度
loc.shape[:-1] # [:-1]是去掉了 envent 的维度, 剩下的可能是 batch 的维度
) # broadcast_shapes 意思是进行了广播, 如果 matrix 的 batch_shape 是 [2,1], loc 的 batch_shape 是 [1,2], 那么整个的 batch_shape 是广播后的 [2,2]
self.covariance_matrix = covariance_matrix.expand(batch_shape + (-1, -1)) # 之后 covariance_matrix 都被 expand 了
...
event_shape = self.loc.shape[-1:] # 看来就是样本的 shape
从这一段来看, batch_shape
是指创建的实例在进行多少个平行的基本分布, 而 event_shape
是指基本分布的事件(支撑点)维度. 如:
locs = torch.randn(2, 3)
matrixs = torch.randn(2, 3, 3)
covariance_matrixs = torch.bmm(matrixs, matrixs.transpose(1, 2))
normal = distributions.MultivariateNormal(loc=locs, covariance_matrix=covariance_matrixs)
print(normal.batch_shape) # 2
print(normal.event_shape) # 3
print(normal.sample())
##### output #####
torch.Size([2])
torch.Size([3])
tensor([[ 1.8972, -0.3961, -0.1530],
[-0.5018, -2.5110, 0.1293]])
batch 的意思还是那个 batch, 不过这里是指分布的 batch, 而不是数据的 batch. 采样时, 得到一批 samples, 对应每个分布.
还有一个 method 和这两个参数有关: expand
, 因为它是一个抽象 method, 基类中并没有实现, 那就直接看 MultivariateNormal
中的:
def expand(self, batch_shape: torch.Size, _instance=None):
"""
Args:
batch_shape (torch.Size): the desired expanded size.
_instance: new instance provided by subclasses that need to override `.expand`.
Returns:
New distribution instance with batch dimensions expanded to `batch_size`.
"""
new = self._get_checked_instance(MultivariateNormal, _instance)
batch_shape = torch.Size(batch_shape)
loc_shape = batch_shape + self.event_shape
cov_shape = batch_shape + self.event_shape + self.event_shape
new.loc = self.loc.expand(loc_shape)
new._unbroadcasted_scale_tril = self._unbroadcasted_scale_tril
if "covariance_matrix" in self.__dict__:
new.covariance_matrix = self.covariance_matrix.expand(cov_shape)
if "scale_tril" in self.__dict__:
new.scale_tril = self.scale_tril.expand(cov_shape)
if "precision_matrix" in self.__dict__:
new.precision_matrix = self.precision_matrix.expand(cov_shape)
super(MultivariateNormal, new).__init__(
batch_shape, self.event_shape, validate_args=False
)
new._validate_args = self._validate_args
return new
这个 method 会创建一个新的 instance 或调用的时候用户提供, 并设置 batch_shape
为参数提供的形状, 然后把参数 expand
到新的 batch_shape
. 用法:
mean = torch.randn(3)
matrix = torch.randn(3, 3)
covariance_matrix = torch.mm(matrix, matrix.t())
mvn = MultivariateNormal(mean, covariance_matrix)
bmvn = mvn.expand(torch.Size([2]))
print(bmvn.batch_shape)
print(bmvn.event_shape)
print(bmvn.sample())
##### output #####
torch.Size([2])
torch.Size([3])
tensor([[-4.0891, -4.2424, 6.2574],
[ 0.7656, -0.2199, -0.9836]])
[小结]: 关于 batch_shape
和 event_shape
的意义
Distribution
允许创建一批同类型分布,batch_shape
是指"批的形状", 如一批3
元MultivariateNormal
分布, 你需要提供shape=[batch_shape|3]
的一批3
维向量, 以及shape=[batch_shape|3,3]
的一批3x3
协方差矩阵.event_shape
就是支撑点的shape
, 如3
元MultivariateNormal
分布的支撑点的shape=(3,)
.
1.3 一些属性
包括: m e a n mean mean, m o d e mode mode, s t d std std, v a r i a n c e variance variance, e n t r o p y entropy entropy 等基本属性, 都需要用户在子类中自己实现. 还有一些相关的函数:
- cumulative density/mass function
cdf(value)
; - inverse cumulative density/mass function
icdf(value)
;
这个函数非常有用, Inverse Transform Sampling 中用其进行采样. 从 U ( 0 , 1 ) U(0,1) U(0,1) 中采样一个 u u u, 然后令 x = F − 1 ( u ) x = F^{-1}(u) x=F−1(u) 就是所求随机变量 X X X 的一个采样. - log of the probability density/mass function
log_prob(value)
, 对数概率.
注意, 目前看到的只有 log_prob
, 并没有 prob
, 一些示例要么只算 log_prob
, 要么计算后通过 exp(log_prob)
得到 prob
.
2. constraints.Constraint
前面在1.1参数验证中已经遇到 constraints.dependent(is_discrete=False, event_dim=0)
和 constraint.check(value)
, 但没有讲具体细节. 本节将详细剖析.
2.1 抽象基类 Constraint
先看源码:
class Constraint:
"""
一个 constraint 对象, 表示变量在某区域内有效, 即变量可优化的范围.
"""
is_discrete = False # Default to continuous.
event_dim = 0 # Default to univariate.
def check(self, value):
"""
结果的形状为"sample_shape + batch_shape", 指示 each event 值是否满足此限制.
"""
raise NotImplementedError
这是抽象基类 Constraint
, 比较简单, 只有两个类属性和一个 method check(value)
. is_discrete
表示待验证值是否为离散; 联想前面的 event_shape
, 大概可以知道 event_dim
是指 len(event_shape)
.(不过目前看只是为了验证参数, 还能验证采样的 event?)
2.2 _Dependent()
不被验证
这个基类信息太少, 对我们理解前面的内容毫无用处, 还是直接观察一些子类吧. 从 dependent = _Dependent()
开始, 它是 constraints.py
中定义好的 placeholder(这个倒是可以学一学):
class _Dependent(Constraint): # 看"_", 应该是不希望用户直接创建实例
"""
Placeholder for variables whose support depends on other variables.
These variables obey no simple coordinate-wise constraints.
"""
def __init__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented):
self._is_discrete = is_discrete
self._event_dim = event_dim
super().__init__()
def __call__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented):
"""
Support for syntax to customize static attributes::
constraints.dependent(is_discrete=True, event_dim=1)
"""
if is_discrete is NotImplemented: # 未提供就是默认
is_discrete = self._is_discrete
if event_dim is NotImplemented:
event_dim = self._event_dim
return _Dependent(is_discrete=is_discrete, event_dim=event_dim)
def check(self, x):
raise ValueError("Cannot determine validity of dependent constraint")
闹了半天, 我们并不能看到 constraints.dependent(is_discrete=False, event_dim=0)
有什么卵用, 只知道 “Cannot determine validity of dependent constraint”, 这也呼应了前面的:
if constraints.is_dependent(constraint):
continue # skip constraints that cannot be checked
也就是说, dependent
类型的限制是不会执行参数验证的. 那这个 _Dependent
到底有何用处? 先不管了.
2.3 _IndependentConstraint
重新解释 event_dim
我们看点复杂的, MultivariateNormal.arg_constraints
:
arg_constraints = {
"loc": constraints.real_vector,
"covariance_matrix": constraints.positive_definite,
"precision_matrix": constraints.positive_definite,
"scale_tril": constraints.lower_cholesky,
}
这些都是 constraints.py
中定义好的实例, 对于大多情况, 这些预定义好的实例已经够用, 但如果需要, 你也可以自定义. 先看 real_vector
:
independent = _IndependentConstraint
real_vector = independent(real, 1)
class _IndependentConstraint(Constraint):
"""
封装一个 constraint, 通过 aggregating over ``reinterpreted_batch_ndims``-many dims in :meth:`check`,
an event is valid 当且仅当它依赖的所有 entries 是 valid 的.
"""
def __init__(self, base_constraint, reinterpreted_batch_ndims):
self.base_constraint = base_constraint
self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
super().__init__()
@property
def event_dim(self):
# real.event_dim 是 0, + real_vector(reinterpreted_batch_ndims=1) = 1
return self.base_constraint.event_dim + self.reinterpreted_batch_ndims
def check(self, value):
result = self.base_constraint.check(value) # 首先要符合 base.check
if result.dim() < self.reinterpreted_batch_ndims:
# 给 batch 留够 dim
expected = self.base_constraint.event_dim + self.reinterpreted_batch_ndims
raise ValueError(
f"Expected value.dim() >= {expected} but got {value.dim()}"
)
result = result.reshape( # 减掉 event
result.shape[: result.dim() - self.reinterpreted_batch_ndims] + (-1,)
)
result = result.all(-1) # 减少一个 dim
return result
意思很明了了, real_vector
是依赖于 real
(base_constraint) 的, reinterpreted_batch_ndims=1
是说把原来 value
的 batch_dim
重新解释, 分出 n
个给 event_dim
: 加上 reinterpreted_batch_ndims
, 比如
value = [[1, 2, 3],
[4, 5, 6]]
本来 real
的 event_dim=0
, 验证结果为(sample_shape + batch_shape = (2,2)
):
value = [[True, True, True],
[True, True, True]]
现在重新解释为 event_dim=1
, 验证结果为:
result = result.reshape( # 减掉 event
result.shape[: result.dim() - self.reinterpreted_batch_ndims] + (-1,) # (-1,) 表示新 event 内的所有 entries 展平
)
result = result.all(-1) # 新 event 内的所有 entries 为 True, 则新 event 为 True
================>
value = [True, True]
3. Transform
& _InverseTransform
上一节介绍了 constraints.Constraint
, 明白了在构建 Distribution
实例时进行的参数验证, 以保证用户提供的参数符合要求. 但还留下了一个疑问: Constraint
中的 event_dim
是指 len(event_shape)
, 难道还能验证采样的 event? 再者, check(value)
返回值的形状是 sample_shape + batch_shape
, 进一步说明它是会被用于采样结果检查的. 让我们看一看能否在 Transform
中找到答案.
Transform
& _InverseTransform
是一对互逆的操作, 实现从一个分布到另一个分布的转换. 这很有用, 因为 distributions
包已经实现了很多常见分布和转换, 自由组合威力巨大. 本节将详细介绍它是如何实现对分布的转换的.
[注] 从 _InverseTransform
的_
来看, 是不需要用户了解它的.
3.1 抽象类 Transform
的基本信息
class Transform:
"""
变换的抽象基类, 子类应该实现 one or both of `_call` or `_inverse`.
如果 `bijective=True`, 则必须实现 `log_abs_det_jacobian`.
Args:
cache_size (int): If one, the latest single value is cached.
Only 0 and 1 are supported.
"""
bijective = False # Transform 是否双射, 默认 False
domain: constraints.Constraint # 有效输入范围
codomain: constraints.Constraint # 有效输出范围
def __init__(self, cache_size=0):
self._cache_size = cache_size
self._inv = None
if cache_size == 0:
pass # default behavior
elif cache_size == 1:
self._cached_x_y = None, None
else:
raise ValueError("cache_size must be 0 or 1")
super().__init__()
果然, Transform
中有 Constraint
的, 分别是 domain
和 codomain
, 用于其检查输入输出是否符合要求. 此外, 还有 bijective
和 cache_size
这两个信息, 等一下看后面怎么说.
3.2 AffineTransform
抽象类的基本信息不多, 还是要看一个简单的例子: AffineTransform
, 线性变换.
class AffineTransform(Transform):
bijective = True
def __init__(self, loc, scale, event_dim=0, cache_size=0):
super().__init__(cache_size=cache_size)
self.loc = loc
self.scale = scale
self._event_dim = event_dim
线性变换是可逆的, 可以看到它的 bijective = True
. 参数是
y
=
l
o
c
+
s
c
a
l
e
×
x
y = loc + scale × x
y = loc + scale × x 中的 loc
和 scale
; event_dim
则是用于构建 domain
和 codomain
:
@constraints.dependent_property(is_discrete=False)
def domain(self):
if self.event_dim == 0:
return constraints.real
return constraints.independent(constraints.real, self.event_dim)
@constraints.dependent_property(is_discrete=False)
def codomain(self):
if self.event_dim == 0:
return constraints.real
return constraints.independent(constraints.real, self.event_dim)
即, domain
和 codomain
被限制为 event_dim
维向量, 默认是 0
, 输入输出皆为标量.
变换过程
def _call(self, x):
"""
Method to compute forward transformation.
"""
return self.loc + self.scale * x
def _inverse(self, y):
"""
Method to compute inverse transformation.
"""
return (y - self.loc) / self.scale
由于是双射, 还要实现:
def log_abs_det_jacobian(self, x, y):
shape = x.shape
scale = self.scale
if isinstance(scale, numbers.Real):
result = torch.full_like(x, math.log(abs(scale)))
else:
result = torch.abs(scale).log()
if self.event_dim:
result_size = result.size()[: -self.event_dim] + (-1,)
result = result.view(result_size).sum(-1)
shape = shape[: -self.event_dim]
return result.expand(shape)
计算结果的形状调整为 x
中除 event_dim
以外的形状, 即 sample_shape + batch_shape
. 至于为什么要这么做, 还需要看 TransformedDistribution
中具体的转换流程.
但这里有个问题, 假设 event_dim=1
, 输入的 x.shape=(2,3)
, 而 scale=2.0
和 scale=torch.tensor(2.0)
的计算结果是不一致的:
====================== scale=2.0 ==========================
result = torch.full_like(x, math.log(abs(2.0)))
[[log(2), log(2), log(2)],
[log(2), log(2), log(2)]]
result_size = (2,3)[: -1] + (-1,) = (2,3)
result = [3log(2), 3log(2)].expand([2]) = [3log(2), 3log(2)]
================== scale=tensor(2.0) =======================
result = torch.abs(scale).log() = log(2)
result_size = ()[: -1] + (-1,) = (-1,)
result = log(2).expand([2]) = [log(2), log(2)]
类似的, 只要 scale
是 tensor
, 并出现了计算广播, 就会出现这种情况. 不知道会不会造成计算错误, 看了后面的 TransformedDistribution
就能知道. 现在只能暂时不管了.
经测验, 果然是有问题的:
mn = distributions.MultivariateNormal(torch.tensor([0., 0., 0.]), torch.eye(3))
amn1 = distributions.TransformedDistribution(
mn,
[distributions.AffineTransform(torch.tensor([1., 2., 3.]), torch.tensor(2.0), event_dim=1)]
)
amn2 = distributions.TransformedDistribution(
mn,
[distributions.AffineTransform(torch.tensor([1., 2., 3.]), 2.0, event_dim=1)]
)
x = mn.sample(torch.Size([2]))
xs = x * 2.0 + torch.tensor([1., 2., 3.])
print(amn1.log_prob(xs))
print(amn2.log_prob(xs))
########## output ###########
tensor([0.6931, 0.6931]) # 输出的 log_abs_det_jacobian for scale=torch.tensor(2.0)
tensor([-4.6540, -4.7094])
tensor([2.0794, 2.0794]) # 输出的 log_abs_det_jacobian scale=2.0
tensor([-6.0403, -6.0957])
计算的概率密度也不一致, 理论上应该是后者对. 所以, 源代码编写者的原意可能是: 只要 scale
是 tensor
, 那么我就按你是不 broadcast
的.
3.3 TransformedDistribution
3.3.1 基本信息
class TransformedDistribution(Distribution):
"""
Extension of the Distribution class, which applies a sequence of Transforms
to a base distribution.
"""
arg_constraints: Dict[str, constraints.Constraint] = {}
def __init__(self, base_distribution, transforms, validate_args=None):
>>> 单 transfrom 变成 [transfrom], 再检查是否符合 transforms: List[Transform] <<<
它是对 Distribution
的扩展, 对一个 base distribution
实施一连串的 Transforms
:
X ~ BaseDistribution
Y = f(X) ~ TransformedDistribution(BaseDistribution, f)
log p(Y) = log p(X) + log |det (dX/dY)|
一个简单的例子:
# #################################
# Building a Logistic Distribution
# X ~ Uniform(0, 1)
# f = a + b * logit(X)
# Y ~ f(X) ~ Logistic(a, b)
# #################################
base_distribution = Uniform(0, 1)
transforms = [SigmoidTransform().inv, AffineTransform(loc=a, scale=b)]
logistic = TransformedDistribution(base_distribution, transforms)
其中 l o g i t ( x ) = l o g x 1 − x logit(x) = log\frac{x}{1-x} logit(x)=log1−xx 是 s i g m o i d sigmoid sigmoid 函数的逆.
下面是 TransformedDistribution
的 __init__(...)
内容(省略了开头将单 Transform
转换为列表以及检查类型的代码):
# >>> Reshape base_distribution according to transforms. >>>
# >>> 获取 base_distribution 的 batch_shape 和 event_shape 以及 event_dim >>>
base_shape = base_distribution.batch_shape + base_distribution.event_shape
base_event_dim = len(base_distribution.event_shape) # 的基本 shape
# <<< 获取 base_distribution 的 batch_shape 和 event_shape 以及 event_dim <<<
# 将 transforms 组合成一个 transform
transform = ComposeTransform(self.transforms)
# 先正向传播 shape, 再反向传播 shape, 一来一回 shape 不一致, 说明途中发生了广播
# 具体例子可为: 线性转换中的 [1,2,3] * [[2],[3]], 输入向量输出矩阵(再反向也是矩阵)
forward_shape = transform.forward_shape(base_shape)
expanded_base_shape = transform.inverse_shape(forward_shape)
if base_shape != expanded_base_shape: # 不一致说明发生了广播 (AffineTransform为例)
base_batch_shape = expanded_base_shape[
: len(expanded_base_shape) - base_event_dim
] # 干脆先把 base_distribution 给 expand 了
# 如 base_shape = batch_shape + event_shape = (,) + (,3) = (3,)
# expanded_base_shape = (2,3), 则 base_batch_shape = (2,)
base_distribution = base_distribution.expand(base_batch_shape) # 结果 base_shape = (2,3)
# transform.domain.event_dim 是指所有 transforms 中最大的 domain.event_dim (这个 domain.event_dim 可能就只是为了检查 dim 是否够用)
reinterpreted_batch_ndims = transform.domain.event_dim - base_event_dim
if reinterpreted_batch_ndims > 0:
base_distribution = Independent( # 但却实实在在地调整了 base_distribution 的 event_dim
base_distribution, reinterpreted_batch_ndims
) # 参考前面讲的 _IndependentConstraint
self.base_dist = base_distribution
这一部分的主旋律是 Reshape base_distribution according to transforms. 也就是说, self.base_dist
被赋予的是调整过的 base_distribution
. 主要包括:
- 调整
batch_shape
, bybase_distribution.expand(base_batch_shape)
, 前面讲过expand
; - 调整
event_shape
, byIndependent
, 这个类似前面讲的_IndependentConstraint
, 只不过这里是对Distribution
操作;
然而基类
Distribution
默认的expand
未实现, 所以如果预期Transforms
链中间会有接口不兼容的时候, 要注意实现expand
, 否则出错. 不过出错也好, 让你知道不兼容, 从而减少未知性.
具体过程看注释. 所以, 使用这种方式建立新的 Distribution
时, 要同时注意 base_distribution
和 transforms
的 event_dim
, 这对 log_prob
的计算有影响, 且 base_distribution
的 event_dim
可能被更改.
安排好 self.base_dist
后, 开始计算本 TransformedDistribution
的 batch_shape
和 event_shape
.
# Compute shapes.
transform_change_in_event_dim = ( # transform 导致的 event_dim 变化
transform.codomain.event_dim - transform.domain.event_dim
)
event_dim = max(
transform.codomain.event_dim, # the transform is coupled
base_event_dim + transform_change_in_event_dim, # the base dist is coupled
)
assert len(forward_shape) >= event_dim
cut = len(forward_shape) - event_dim # forward_shape 劈开
batch_shape = forward_shape[:cut]
event_shape = forward_shape[cut:]
super().__init__(batch_shape, event_shape, validate_args=validate_args)
3.3.2 采样
def sample(self, sample_shape=torch.Size()):
with torch.no_grad():
x = self.base_dist.sample(sample_shape)
for transform in self.transforms:
x = transform(x)
return x
def rsample(self, sample_shape=torch.Size()):
x = self.base_dist.rsample(sample_shape)
for transform in self.transforms:
x = transform(x)
return x
3.3.3 log_prob
需要 log_abs_det
def log_prob(self, value):
if self._validate_args: # 验证样本的就在此处了
self._validate_sample(value)
event_dim = len(self.event_shape)
log_prob = 0.0
y = value
for transform in reversed(self.transforms): # 倒着来
x = transform.inv(y) # 逆变换得到 x, 想计算 `log_prob`, 逆变换就得实现.
event_dim += transform.domain.event_dim - transform.codomain.event_dim
log_prob = log_prob - _sum_rightmost(
transform.log_abs_det_jacobian(x, y),
event_dim - transform.domain.event_dim,
)
y = x
log_prob = log_prob + _sum_rightmost(
self.base_dist.log_prob(y), event_dim - len(self.base_dist.event_shape)
)
return log_prob
根据 f y ( y ) = f X ( x ) / ∣ d y d x ∣ f_y(y) = f_X(x)/|\frac{dy}{dx}| fy(y)=fX(x)/∣dxdy∣, 比较容易理解 l o g f y ( y ) = l o g f X ( x ) − l o g ∣ d y d x ∣ logf_y(y) = logf_X(x) - log|\frac{dy}{dx}| logfy(y)=logfX(x)−log∣dxdy∣, 那么代码中大概的框架是连续减 l o g ∣ d y d x ∣ log|\frac{dy}{dx}| log∣dxdy∣, 到这没什么问题. 问题就在于为什么要执行:
_sum_rightmost(
transform.log_abs_det_jacobian(x, y),
event_dim - transform.domain.event_dim,
)
把 event_dim - transform.domain.event_dim
个最右侧的维度加起来? 空想不好理解, 必须举个例子, 还是以 AffineTransform
为例:
当 event_dim > transform.domain.event_dim
假设有:
value = [[1, 2, 3],
[4, 5, 6]], event_dim=1
affine = AffineTransform(1, 2, event_dim=0) # transform.domain.event_dim = 0
则计算的 log_abs_det
为:
[[log2, log2, log2],
[log2, log2, log2]]
但按照 event_dim=1
, 即基本的 event
单元为 [1, 2, 3]
和 [4, 5, 6]
, 对应的
l
o
g
∣
d
y
d
x
∣
log|\frac{dy}{dx}|
log∣dxdy∣ 为 [3log2, 3log2]
, 这就用到了上面说的 _sum_rightmost
, 把 event_dim - transform.domain.event_dim
个最右侧的维度加起来.
而在在一连串的 transforms
中, event_dim += transform.domain.event_dim - transform.codomain.event_dim
代表着当前 x
的 event_dim
.
event_dim = transform.domain.event_dim
时表示刚刚好, 输入符合 transform.codomain.event_dim
的要求. 有没有可能 event_dim < transform.domain.event_dim
? 怕是不能!
4. 实战解析及解惑
在读了 torch.distributions
包的源码之后, 让我们回到论文 The Power Spherical distribution 的代码, 再一看则豁然开朗. 包括边缘变量
t
t
t 和均匀子球的采样
v
\bm{v}
v 的组合操作, 以及组合后的 Householder 变换. 详情请参阅《von Mises-Fisher 分布》.
4.1 边缘变量 t t t 和均匀子球的采样 v \bm{v} v 的组合操作
class _TTransform(tds.Transform):
"""
大概就是 cat(t,v) 的吧, 注意, t 在开头
"""
# 设置为向量还是有必要的, 因为传递过程中的 log_abs_det_jacobian 计算要用到 event_dim
# real.event_dim=0, real_vector.event_dim=1
# 关系到传播中 log_prob 的计算
domain = constraints.real # 输入空间是实数
codomain = constraints.real # 输出空间也是实数
def _call(self, x):
t = x[..., 0].unsqueeze(-1)
v = x[..., 1:]
return torch.cat([t, v * torch.sqrt(torch.clamp(1 - t ** 2, _EPS))], -1)
def _inverse(self, y):
t = y[..., 0].unsqueeze(-1)
v = y[..., 1:]
return torch.cat([t, v / torch.sqrt(torch.clamp(1 - t ** 2, _EPS))], -1)
def log_abs_det_jacobian(self, x, y):
"""
计算变换后的分布的概率密度时有用 fY(y) = fX(x(y))|dx/dy|
:param x: input
:param y: output
:return: the log det jacobian log |dy/dx| given input and output
"""
t = x[..., 0]
# return ((x.shape[-1] - 3) / 2) * torch.log(torch.clamp(1 - t ** 2, _EPS)) # 怎么感觉是 (d-1)/2?
return ((x.shape[-1] - 1) / 2) * torch.log(torch.clamp(1 - t ** 2, _EPS))
_call
和 _inverse
都是对的, 但感觉 log_abs_det_jacobian
有问题. 首先我们从数学上先推导一下这个变换的
d
y
d
x
\frac{dy}{dx}
dxdy. 设
[
t
,
1
−
t
2
v
1
,
⋯
,
1
−
t
2
v
m
−
1
]
=
t
t
r
a
n
s
f
o
r
m
(
[
t
,
v
1
,
⋯
,
v
m
−
1
]
)
[t, \sqrt{1-t^2}v_1, \cdots, \sqrt{1-t^2}v_{m-1}] = ttransform([t, v_1, \cdots, v_{m-1}])
[t,1−t2v1,⋯,1−t2vm−1]=ttransform([t,v1,⋯,vm−1]), 其中
m
m
m 是向量的维度. 那么雅可比矩阵为:
J
=
[
1
0
⋯
0
−
t
v
1
1
−
t
2
1
−
t
2
⋮
0
⋮
⋮
⋱
⋮
−
t
v
m
−
1
1
−
t
2
0
⋯
1
−
t
2
]
J = \begin{bmatrix} 1 & 0 & \cdots & 0 \\ \frac{-tv_1}{\sqrt{1-t^2}} & \sqrt{1-t^2} & \vdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ \frac{-tv_{m-1}}{\sqrt{1-t^2}}& 0 & \cdots & \sqrt{1-t^2} \end{bmatrix}
J=
11−t2−tv1⋮1−t2−tvm−101−t2⋮0⋯⋮⋱⋯00⋮1−t2
则
d
e
t
(
J
)
=
(
1
−
t
2
)
m
−
1
2
det(J) = (1-t^2)^\frac{m-1}{2}
det(J)=(1−t2)2m−1, 那么
l
o
g
∣
d
y
d
x
∣
=
m
−
1
2
l
o
g
(
1
−
t
2
)
log|\frac{dy}{dx}| = \frac{m-1}{2}log(1-t^2)
log∣dxdy∣=2m−1log(1−t2). 我们看一看代码计算结果是什么, 值是刚才计算的值, 形状为 x.shape[:-1]
. 那么问题来了, 如果不管 domain.event_dim
, 这个计算结果还是对的, 但这里的 domain.event_dim=real.event_dim=0
, 而根据 TransformedDistribution
中的 # Compute shapes
, 可计算得到 self.event_dim=1
, 此时 event_dim - transform.domain.event_dim=1
, 你返回这个值后, 计算 log_prob
会再执行:
_sum_rightmost(
transform.log_abs_det_jacobian(x, y),
event_dim - transform.domain.event_dim
)
从而导致其在 sample_shape[-1]
上相加, 减少一个维度, 这肯定是不对的. 假设我们构造分布:
mn = distributions.MultivariateNormal(torch.zeros(2, 3), torch.eye(3).expand(2, 3, 3))
tmn = distributions.TransformedDistribution(mn, [_TTransform()])
x = tmn.sample(torch.Size([]))
print(x)
print(tmn.log_prob(x))
########## output ##########
tensor([[ 0.1690, -0.2441, -1.0599],
[-0.2891, 0.4383, 2.4503]])
tensor([-3.2637, -6.0629])
哎! 看着没问题啊, 但实际上此时 _TTransform.log_abs_det_jacobian
得到的是 shape=(2,)
的 tensor
, 然后被 _sum_rightmost
减少了一维, 得到一个 scalar
, 再加 MultivariateNormal.log_prob
(shape=[2]) 时是可广播的, 故而不会报错(实际上已经计算错误). 我们找到 TransformedDistribution.log_prob
, 添加一些 print
打印信息:
def log_prob(self, value):
"""
Scores the sample by inverting the transform(s) and computing the score
using the score of the base distribution and the log abs det jacobian.
"""
if self._validate_args:
self._validate_sample(value)
event_dim = len(self.event_shape)
log_prob = 0.0
y = value
for transform in reversed(self.transfrms):
x = transform.inv(y)
event_dim += transform.domain.event_dim - transform.codomain.event_dim
print(event_dim)
print(event_dim - transform.domain.event_dim)
print(transform.log_abs_det_jacobian(x, y))
log_prob = log_prob - _sum_rightmost(
transform.log_abs_det_jacobian(x, y),
event_dim - transform.domain.event_dim,
)
y = x
print(log_prob)
log_prob = log_prob + _sum_rightmost(
self.base_dist.log_prob(y), event_dim - len(self.base_dist.event_shape)
)
return log_prob
输出:
tensor([[-3.0071e+00, 6.7510e-19, -3.1936e-18],
[-3.2723e-01, 1.4745e+00, 7.2704e-01]])
event_dim: 1
event_dim - transform.domain.event_dim: 1
transform.log_abs_det_jacobian(x, y): tensor([-82.8931, -0.1133])
log_prob: 83.0063247680664
tensor([70.4010, 78.6826])
果然不出所料, 在 sample_shape[-1]
上相加, 减少一个维度, 这肯定是不对的. 其实也不一定叫 sample_shape
, 一般采样的 shape = sample_shape + batch_shape + event_shape
, 是在 sample_shape + batch_shape
的最后一维执行 sum
. 假设 shape = sample_shape([4]) + batch_shape([2]) + event_shape([3])
:
mn = distributions.MultivariateNormal(torch.zeros(2, 3), torch.eye(3).expand(2, 3, 3))
tmn = distributions.TransformedDistribution(mn, [_TTransform()])
x = tmn.sample(torch.Size([4]))
print(x)
print(tmn.log_prob(x))
########## output ##########
tensor([[[ 2.7956e-01, -7.6880e-01, -1.5509e+00],
[ 8.0168e-01, 1.6566e-01, 2.4777e-02]],
[[-8.6095e-01, 1.9766e-01, -2.6489e-01],
[-1.0264e+00, -2.0986e-18, -7.3761e-19]],
[[-9.9605e-02, 2.5579e-02, -9.6922e-02],
[ 2.0281e+00, 1.4941e-18, 5.9761e-19]],
[[-7.4705e-01, -1.7805e+00, 2.0846e-01],
[ 8.1492e-01, -5.5181e-01, 6.8105e-03]]])
event_dim: 1
event_dim - transform.domain.event_dim: 1
transform.log_abs_det_jacobian(x, y): tensor([[-8.1379e-02, -1.0292e+00],
[-1.3518e+00, -8.2893e+01],
[-9.9707e-03, -8.2893e+01],
[-8.1663e-01, -1.0909e+00]])
log_prob: tensor([ 1.1105, 84.2449, 82.9030, 1.9075])
Traceback (most recent call last):
File "/root/PycharmProjects/OptimalTransport/EBSW/ColorTransfer/zz.py", line 35, in <module>
print(tmn.log_prob(x))
File "/opt/programs/pytorch/torch/distributions/transformed_distribution.py", line 177, in log_prob
log_prob = log_prob + _sum_rightmost(
RuntimeError: The size of tensor a (4) must match the size of tensor b (2) at non-singleton dimension 1
transform.log_abs_det_jacobian(x, y)
的 shape=(4,2)
也即 sample_shape([4]) + batch_shape([2])
, 然而却把 batch_shape([2])
整没了, 然后和
_sum_rightmost(
self.base_dist.log_prob(y), event_dim - len(self.base_dist.event_shape) # =0
)
的 shape=(4,2)
产生冲突, 报错.
即使你转化诸如 Normal
等 event_dim=0
的 Distribution
, 也会出现 shape
不匹配的情况(只有二维样本时形成 scalar
不会报错, 但也计算错误).
[注] 为什么明明是
l
o
g
∣
d
y
d
x
∣
=
m
−
1
2
l
o
g
(
1
−
t
2
)
log|\frac{dy}{dx}| = \frac{m-1}{2}log(1-t^2)
log∣dxdy∣=2m−1log(1−t2), 代码中却写 (x.shape[-1] - 3) / 2
?
答曰: 不清楚.
[小结]: domain
和 codomain
的限制维度很重要, 该是 real_vector
的写成 real
会出现错误.
4.2 Householder Transform
class _HouseholderRotationTransform(tds.Transform):
"""
完成拼接后, 要进行 HouseholderRotation
"""
domain = constraints.real
codomain = constraints.real
def __init__(self, loc: torch.Tensor):
super().__init__()
e1 = torch.zeros_like(loc) # 继承
e1[..., 0] = 1.0
self.__u = tn_func.normalize(e1 - loc, dim=-1)
def _call(self, x: torch.Tensor):
return x - 2 * (x * self.__u).sum(-1, keepdim=True) * self.__u
def _inverse(self, y: torch.Tensor): # 逆变换是一样的
return y - 2 * (y * self.__u).sum(-1, keepdim=True) * self.__u
def log_abs_det_jacobian(self, x: torch.Tensor, y: torch.Tensor):
# h = torch.eye(x.shape[-1], device=x.device) - 2 * torch.outer(self.__u, self.__u)
# torch.log(torch.abs(torch.det(h)))
return 0.0 # 因为 |y|=|x|, 所以 |h|=1; 正交矩阵
具体原理见《householder 变换》. 现在我们关注 log_abs_det_jacobian
.
y
=
(
I
−
2
u
u
⊺
)
x
\bm{y} = (I - 2\bm{u}\bm{u}^\intercal) \bm{x}
y=(I−2uu⊺)x, 所以雅可比矩阵为
J
=
I
−
2
u
u
⊺
J = I - 2\bm{u}\bm{u}^\intercal
J=I−2uu⊺,
∣
d
e
t
(
I
−
2
u
u
⊺
)
∣
=
1
|det(I - 2\bm{u}\bm{u}^\intercal)| = 1
∣det(I−2uu⊺)∣=1,
l
o
g
∣
d
e
t
(
J
)
∣
=
0
log|det(J)| = 0
log∣det(J)∣=0.
同样, 值的计算是没有问题的, 只是其返回值 0.0
的 shape
不对, event_dim - transform.domain.event_dim=1
, 你返回一个 0.0
, TransformedDistribution
中的 log_prob
无法计算. 直接报错:
AttributeError: 'int' object has no attribute 'shape'
但是, 一旦把 0
改成 torch.tensor(0.0)
, _sum_rightmost(1)
不会对其有影响, 出来依然是标量, 能广播, 也没有加和的数值错误, 倒成了正确的.
[注]: 正交矩阵的行列式值都为 1 1 1.
4.3 代码的作者自己实现了 log_prob
既然两个转换的 log_abs_det
都有问题, 那为什么还取得了正确的测试结果? 答案是: 作者自己实现了 log_prob
的计算, 而并未使用 TransformedDistribution
中的 log_prob
.
def log_prob(self, value):
return self.log_normalizer() + self.scale * torch.log1p(
(self.loc * value).sum(-1)
)
假设我们注释掉这个作者实现的 log_prob
, 转而使用 TransformedDistribution
中的 log_prob
, 看看会有正确的结果不:
loc = torch.tensor([0.0, 1.0], requires_grad=True)
scale = torch.tensor(4.0, requires_grad=True)
dist = PowerSpherical(loc, scale)
step_size = 0.001
x = torch.arange(0, 2 * math.pi, step_size)
pt = torch.stack((torch.cos(x), torch.sin(x))).t()
y = torch.exp(dist.log_prob(pt)).detach()
print('integal:', y.sum() * step_size)
###################### output #######################
integal: tensor(inf)
竟然没有报错, 只是输出了一个 tensor(inf)
. 经过检查, 发现忽略了各 Distribution
的 event_shape
, 作者竟然都设置成了默认, 即 torch.Size([])
, 且各 Transform
的 forward_shape
和 inverse_shape
都保持默认, 那么送进 TransformedDistribution
后, base_distribution
不会被 expand
, 且 event_shape=torch.Size([])
, 进而 event_dim - transform.domain.event_dim = 0 -0 = 0
, _sum_rightmost
从来都不会计算.
即使 _HouseholderRotationTransform.log_abs_det_jacobian
返回值为 0
, 经过:
log_prob = log_prob - _sum_rightmost(
transform.log_abs_det_jacobian(x, y),
event_dim - transform.domain.event_dim,
)
链条的广播计算, 也不是问题. 可以正确计算了? 那为什么是 inf
? 问题就在于边缘变量 t
的 log_prb
的计算, 当
t
=
1
t=1
t=1 (
x
=
μ
\bm{x}=\bm{\mu}
x=μ) 时, Beta
分布的 log_prob(1)=log(0)
, 从而出现无穷的情况, 而这本来能乘以该处的均匀子球概率密度
1
0
\frac{1}{0}
01 避免的.
4.4 心得与教训
可以看到, 这种 BaseDistribution
+ Transform
实现复杂分布的方式提供方便的同时, 也会存在很多问题. 这里总结几点心得与教训:
- 当基础分布和变换都已经存在时, 使用此架构可以快速搭建新的概率分布, 不必考虑太多细节, 比如
log_pdf
,cdf
以及一些基础属性; - 当需要实现一个复杂分布时, 尽可能拆解成简单分布和变换, 注意要朝着已存在变换的方向分解, 哪怕拆出来的基础分布不在 PyTorch 的包内, 只要能使分布更简单, 就能达到简化的目的;
- 尽量不要自己写
Transform
, 因为要实现的东西有点多,_inverse
和log_abs_det_jacobian
都比较麻烦, 有那个功夫, 也已经把复杂分布的pdf
算出来了; - 尽量直接继承
Distribution
实现自己的分布, 为不是使用转换链. 因为转换意味着要计算log_abs_det_jacobian
, 然后沿着转换链累加, 这比直接计算目标分布的log_prob
增加了许多计算, 且承担不稳定的风险. 比如此例子中的接连三个变换, 计算复杂不说, 还拆开了 x = μ \bm{x} = \bm{\mu} x=μ 时的 f ( t = 1 ) = 0 f(t=1) = 0 f(t=1)=0 和 f ( 1 − t 2 v ) = 1 S p − 2 = 1 0 f(\sqrt{1-t^2}\bm{v}) = \frac{1}{S_{p-2}} = \frac{1}{0} f(1−t2v)=Sp−21=01, 加之为了避免 0 0 0 作为分母而加上的 e p s eps eps 偏移, 使得计算失败. - 如果只是为了采样, 则不必考虑
_inverse
和log_abs_det_jacobian
, 因为它们只出现在log_prob
中.
4.5 关于 batch_shape
, event_shape
和 Constraint.event_dim
.
如果迫不得已需要使用转换链, 我还是建议尽量实事求是地把这三个参数正确地写上, 而不是全部保持为 0. 一则不符合代码逻辑, 二来不定在哪就出错了. 包的作者既然这么设置, 肯定有他的道理.
附录
1. __debug__
和 assert
(来自 Kimi)
__debug__
是一个内置变量,用于指示 Python 解释器是否处于调试模式。当 Python 以调试模式运行时,__debug__
被设置为 True
;否则,在优化模式下运行时,它被设置为 False
。
__debug__
可以用于条件性地执行调试代码,例如:
if __debug__:
print("Debug mode is on, performing extra checks...")
# 这里可以放一些只在调试模式下运行的代码,比如详细的日志记录
# 或者复杂的验证逻辑
else:
print("Debug mode is off.")
在上面的例子中,如果命令行执行:
python -O myscript.py
##### output #####
Debug mode is off.
------------------------------------------------------
python myscript.py
##### output #####
Debug mode is on, performing extra checks...
assert
语句受 __debug__
影响:
def calculate(a, b):
# 这个 assert 在 __debug__ 为 True 时执行
assert a > 0 and b > 0, "Both inputs must be positive."
# 正常的函数逻辑
return a * b
# 在这里,assert 会检查输入是否为正数
result = calculate(5, 3)
print(result)
# 如果我们改变条件使 assert 失败
# result = calculate(-1, 3) # 这会触发 AssertionError,除非运行时 __debug__ 为 False
2. t t t 的概率密度函数推导
直接将 t = μ ⊺ x t = \bm{\mu}^\intercal\bm{x} t=μ⊺x 代入 f p ( x ; μ , κ ) f_p(\bm{x}; \bm{\mu}, \kappa) fp(x;μ,κ), 得: f p ( x ; μ , κ ) = C p ( κ ) ( 1 + μ ⊺ x ) κ = C p ( κ ) ( 1 + t ) κ t ∈ [ − 1 , 1 ] = C p ( κ ) ( 1 + c o s θ ) κ θ ∈ [ 0 , π ] \begin{aligned} f_p(\bm{x}; \bm{\mu}, \kappa) &= C_p(\kappa) (1 + \bm{\mu}^\intercal\bm{x})^\kappa & \\ &= C_p(\kappa) (1 + t)^\kappa & t \in [-1, 1] \\ &= C_p(\kappa) (1 + cos\theta)^\kappa & \theta \in [0, \pi] \end{aligned} fp(x;μ,κ)=Cp(κ)(1+μ⊺x)κ=Cp(κ)(1+t)κ=Cp(κ)(1+cosθ)κt∈[−1,1]θ∈[0,π] 注意这是 x \bm{x} x 一点处的概率密度. 沿着 t t t 处的切子球求积分, 以得到 t t t 或 θ \theta θ 处的整个概率密度: ∫ 切子球 f p ( x ; μ , κ ) d s = ∫ 切子球 C p ( κ ) ( 1 + μ ⊺ x ) κ d s = C p ( κ ) ( 1 + t ) κ 2 π p − 1 2 Γ ( p − 1 2 ) ( 1 − t 2 ) p − 2 2 S p − 2 的表面积 ∝ r p − 2 = C p ( κ ) ( 1 + c o s θ ) κ 2 π p − 1 2 Γ ( p − 1 2 ) s i n p − 2 θ \begin{aligned} \int_{切子球} f_p(\bm{x}; \bm{\mu}, \kappa) ds &= \int_{切子球} C_p(\kappa) (1 + \bm{\mu}^\intercal\bm{x})^\kappa ds & \\ &= C_p(\kappa) (1 + t)^\kappa\frac{2\pi^{\frac{p-1}{2}}}{\Gamma(\frac{p-1}{2})}(1-t^2)^\frac{p-2}{2} & S^{p-2} 的表面积 \propto r^{p-2} \\ &= C_p(\kappa) (1 + cos\theta)^\kappa \frac{2\pi^{\frac{p-1}{2}}}{\Gamma(\frac{p-1}{2})} sin^{p-2}\theta & \end{aligned} ∫切子球fp(x;μ,κ)ds=∫切子球Cp(κ)(1+μ⊺x)κds=Cp(κ)(1+t)κΓ(2p−1)2π2p−1(1−t2)2p−2=Cp(κ)(1+cosθ)κΓ(2p−1)2π2p−1sinp−2θSp−2的表面积∝rp−2 根据 n-sphere - Wikipedia, 切子球 S p − 2 S^{p-2} Sp−2 的表面积 S p − 2 = 2 π p − 1 2 Γ ( p − 1 2 ) r p − 2 S_{p-2} = \frac{2\pi^{\frac{p-1}{2}}}{\Gamma(\frac{p-1}{2})} r^{p-2} Sp−2=Γ(2p−1)2π2p−1rp−2, 再沿 t t t 或 θ \theta θ 积分: ∫ 0 π C p ( κ ) ( 1 + c o s θ ) κ 2 π p − 1 2 Γ ( p − 1 2 ) s i n p − 2 θ d θ = C p ( κ ) 2 π p − 1 2 Γ ( p − 1 2 ) ∫ 1 − 1 ( 1 + t ) κ ( 1 − t 2 ) p − 2 2 ( − 1 1 − t 2 d t ) ∵ c o s 0 = 1 , c o s π = − 1 = C p ( κ ) 2 π p − 1 2 Γ ( p − 1 2 ) ∫ − 1 1 ( 1 + t ) κ ( 1 − t 2 ) p − 3 2 d t = C p ( κ ) 2 π p − 1 2 Γ ( p − 1 2 ) ∫ − 1 1 ( 1 + t ) κ [ ( 1 + t ) ( 1 − t ) ] p − 3 2 d t = C p ( κ ) 2 π p − 1 2 Γ ( p − 1 2 ) ∫ − 1 1 ( 1 + t ) p − 1 2 + κ − 1 ( 1 − t ) p − 1 2 − 1 d t = C p ( κ ) 2 π p − 1 2 Γ ( p − 1 2 ) ∫ 0 1 ( 2 z ) p − 1 2 + κ − 1 ( 2 ( 1 − z ) ) p − 1 2 − 1 2 d z t = 2 z − 1 = C p ( κ ) 2 p + κ − 1 π p − 1 2 Γ ( p − 1 2 ) ∫ 0 1 z p − 1 2 + κ − 1 ( 1 − z ) p − 1 2 − 1 d z \begin{aligned} & \int_0^\pi C_p(\kappa) (1 + cos\theta)^\kappa \frac{2\pi^{\frac{p-1}{2}}}{\Gamma(\frac{p-1}{2})} sin^{p-2}\theta d\theta \\ =& C_p(\kappa) \frac{2\pi^{\frac{p-1}{2}}}{\Gamma(\frac{p-1}{2})} \int_{1}^{-1} (1 + t)^\kappa (1-t^2)^\frac{p-2}{2} (\frac{-1}{\sqrt{1-t^2}} dt) & \because cos0=1,~ cos\pi=-1 \\ =& C_p(\kappa) \frac{2\pi^{\frac{p-1}{2}}}{\Gamma(\frac{p-1}{2})} \int_{-1}^{1} (1 + t)^\kappa (1-t^2)^{\frac{p-3}{2}} dt \\ =& C_p(\kappa) \frac{2\pi^{\frac{p-1}{2}}}{\Gamma(\frac{p-1}{2})} \int_{-1}^{1} (1 + t)^\kappa [(1+t)(1-t)]^{\frac{p-3}{2}} dt \\ =& C_p(\kappa) \frac{2\pi^{\frac{p-1}{2}}}{\Gamma(\frac{p-1}{2})} \int_{-1}^{1} (1 + t)^{\frac{p-1}{2}+\kappa-1} (1-t)^{\frac{p-1}{2}-1} dt \\ =& C_p(\kappa) \frac{2\pi^{\frac{p-1}{2}}}{\Gamma(\frac{p-1}{2})} \int_{0}^{1} (2z)^{\frac{p-1}{2}+\kappa-1} (2(1-z))^{\frac{p-1}{2}-1} 2dz & t=2z-1 \\ =& C_p(\kappa) \frac{2^{p+\kappa-1}\pi^{\frac{p-1}{2}}}{\Gamma(\frac{p-1}{2})} \int_{0}^{1} z^{\frac{p-1}{2}+\kappa-1} (1-z)^{\frac{p-1}{2}-1} dz \end{aligned} ======∫0πCp(κ)(1+cosθ)κΓ(2p−1)2π2p−1sinp−2θdθCp(κ)Γ(2p−1)2π2p−1∫1−1(1+t)κ(1−t2)2p−2(1−t2−1dt)Cp(κ)Γ(2p−1)2π2p−1∫−11(1+t)κ(1−t2)2p−3dtCp(κ)Γ(2p−1)2π2p−1∫−11(1+t)κ[(1+t)(1−t)]2p−3dtCp(κ)Γ(2p−1)2π2p−1∫−11(1+t)2p−1+κ−1(1−t)2p−1−1dtCp(κ)Γ(2p−1)2π2p−1∫01(2z)2p−1+κ−1(2(1−z))2p−1−12dzCp(κ)Γ(2p−1)2p+κ−1π2p−1∫01z2p−1+κ−1(1−z)2p−1−1dz∵cos0=1, cosπ=−1t=2z−1 那么, 将 α = p − 1 2 + κ β = p − 1 2 C p ( κ ) 2 α + β π β Γ ( β ) = Γ ( α + β ) Γ ( α ) Γ ( β ) C p ( κ ) = Γ ( α + β ) 2 α + β π β Γ ( α ) \begin{aligned} \alpha =& \frac{p-1}{2}+\kappa \\ \beta =& \frac{p-1}{2} \\ C_p(\kappa) \frac{2^{\alpha+\beta}\pi^{\beta}}{\Gamma(\beta)} =& \frac{\Gamma(\alpha+\beta)}{\Gamma(\alpha)\Gamma(\beta)} \\ C_p(\kappa) =& \frac{\Gamma(\alpha+\beta)}{2^{\alpha+\beta}\pi^{\beta}\Gamma(\alpha)} \end{aligned} α=β=Cp(κ)Γ(β)2α+βπβ=Cp(κ)=2p−1+κ2p−1Γ(α)Γ(β)Γ(α+β)2α+βπβΓ(α)Γ(α+β)
3. ComposeTransform
的 domain.event_dim
def domain(self):
if not self.parts:
return constraints.real
domain = self.parts[0].domain
# Adjust event_dim to be maximum among all parts.
event_dim = self.parts[-1].codomain.event_dim
for part in reversed(self.parts):
event_dim += part.domain.event_dim - part.codomain.event_dim # 一个 Transform 的 event_dim 降了多少, 加回来
# 如果当前 Transform 的 codomain.event_dim 等于下一个 Transform 的 domain.event_dim, 那么下面两者一定相等;
event_dim = max(event_dim, part.domain.event_dim) # Transform 前后口对接不一致才会出现差别, 且大概率是前一个出口大于后一个入口 ->
# -> 假设现在 event_dim 等于后一个 domain.event_dim(事实上刚开始是这样的), 那么 event_dim - part.codomain.event_dim < 0 ->
assert event_dim >= domain.event_dim # -> 意味着 event_dim += part.domain.event_dim - part.codomain.event_dim < part.domain.event_dim
# 从而 event_dim = max(event_dim, part.domain.event_dim) = part.domain.event_dim, 还是和接口一致相同的结果.
# 所以, 只有前一个出口小于后一个入口时, 才会保留计算得到的 event_dim, 而不是更新为 part.domain.event_dim
# ===> 下面一句是结论:
if event_dim > domain.event_dim: # 这个条件正是说明链条中有"前一个出口小于后一个入口"的接口不兼容情况, 故而扩展开头的 domain, 以免出错; ->
domain = constraints.independent(domain, event_dim - domain.event_dim) # 目的是留够 dim, 足够后面的"小->大"消耗不兼容性.
return domain
4. TTransform
的 log_abs_det_jacobian
经过试验, 发现的确是 ((x.shape[-1] - 3) / 2) * torch.log(torch.clamp(1 - t ** 2, _EPS))
对:
class MarginalTDistribution(distributions.Distribution):
arg_constraints = {'kappa': constraints.positive}
has_rsample = True
_EPS = 1e-36
def __init__(self, m, kappa, dtype=None, validate_args=None):
self.m = m
self.kappa = kappa
super().__init__(batch_shape=kappa.shape[:-1], event_shape=kappa.shape[-1:], validate_args=validate_args)
self._dtype = dtype if dtype is not None else kappa.dtype
self._device = kappa.device
# >>> for sampling algorithm >>>
self._uniform = distributions.Uniform(self._EPS, torch.tensor(1.0 - self._EPS, dtype=self._dtype, device=self._device))
self._beta = distributions.Beta(
torch.tensor((self.m - 1) / 2, dtype=self._dtype, device=self._device),
torch.tensor((self.m - 1) / 2, dtype=self._dtype, device=self._device)
)
def sample(self, shape: Union[torch.Size, int] = torch.Size()):
if isinstance(shape, int):
shape = torch.Size([shape])
with torch.no_grad(): # rsample 是 reparameterized sample, 便于梯度更新以调整分布参数
return self.rsample(shape)
def rsample(self, shape=torch.Size()):
"""
Reparameterized Sample: 从一个简单的分布通过一个参数化变换使得其满足一个更复杂的分布;
此处, mu 是可变参数, 通过 radial-tangential decomposition 采样;
梯度下降更新 mu, 以获得满足要求的 vMF.
:param shape: 样本的形状
:return: [shape|m] 的张量, shape 个 m 维方向向量
"""
# shape = torch.Size(shape + self._batch_shape + self._event_shape)
w = (
self._sample_w3(shape=shape)
if self.m == 3
else self._sample_w_rej(shape=shape)
)
return w
def _sample_w3(self, shape: torch.Size):
"""
拒绝采样方法来自: Computer Generation of Distributions on the M-Sphere
https://rss.onlinelibrary.wiley.com/doi/abs/10.2307/2347441
:param shape: 采样 w 的的形状
:return: 形状为 shape 的张量, shape 个 w
"""
# kappa 也是有 shape 的, 说明可以并行多个 κ 吗?
shape = torch.Size(shape + self.kappa.shape) # torch.Size 继承自 tuple, 其 + 运算就是连接操作
# https://en.wikipedia.org/wiki/Von_Mises%E2%80%93Fisher_distribution # 3-D sphere
u = self._uniform.sample(shape)
w = 1 + torch.stack( # 这个公式是按 μ=(0,0,1) 计算的 w, arccosw=φ, 即 w=z
[ # 最后的旋转可能是旋转至按真正的 μ 采样结果
torch.log(u),
torch.log(1 - u) - 2 * self.kappa
],
dim=0
).logsumexp(0) / self.kappa
return w
def _sample_w_rej(self, shape: torch.Size):
num_need = math.prod(shape)
num_kappa = math.prod(self.kappa.shape)
sample_shape = torch.Size([num_kappa, 10 + math.ceil(num_need * 1.3)])
kappa = self.kappa.reshape(-1, 1)
c = torch.sqrt(4 * kappa.square() + (self.m - 1) ** 2)
b = (-2 * kappa + c) / (self.m - 1)
t_0 = (-(self.m - 1) + c) / (2 * kappa)
s = kappa * t_0 + (self.m - 1) * torch.log(1 - t_0.square())
# 大天坑, [[]] * num_kappa 虽然也形成了 [num_kappa]个[]], 但实际上这些 [] 是同一个对象;
# 导致后面的 ts[i] 全是同一个 [], 并 append 了所有不同 kappa 对应的 t 样本;
# [torch.cat(t)[:num_need] for t in ts] 中每个 tensor 也仅仅获取了第一个 kappa 对应的 t 样本;
# 因为 Python 的 * 运算符是浅拷贝, 而不是深拷贝.
ts = [[] for _ in range(num_kappa)] # * num_kappa
cnts = torch.zeros(num_kappa, device=self._device)
while cnts.lt(num_need).any():
y = self._beta.sample(sample_shape)
u = self._uniform.sample(sample_shape)
t = (1 - (1 + b) * y) / (1 - (1 - b) * y)
mask = (kappa * t + (self.m - 1) * torch.log(1 - t_0 * t) - s) > torch.log(u)
for i in range(num_kappa):
ts[i].append(t[i][mask[i]])
cnts += mask.sum(dim=-1)
samples = torch.stack([torch.cat(t)[:num_need] for t in ts], dim=1)
return samples.reshape(shape + self.kappa.shape)
def log_prob(self, value: torch.Tensor) -> torch.Tensor:
result = (self.m / 2 - 1) * torch.log(self.kappa / 2)
result -= math.lgamma(0.5) + math.lgamma(self.m / 2 - 0.5) + torch.log(ive(self.m / 2 - 1, self.kappa)) + self.kappa
result += value * self.kappa + (self.m / 2 - 1 - 0.5) * torch.log(1 - value.square())
return result
class _JointTSDistribution(distributions.Distribution):
arg_constraints = {}
def __init__(self, marginal_t, marginal_s):
super().__init__(batch_shape=marginal_s.batch_shape, event_shape=torch.Size([marginal_s._dim + 1]), validate_args=False)
self.marginal_t, self.marginal_s = marginal_t, marginal_s
def sample(self, sample_shape=()):
with torch.no_grad():
return self.rsample(sample_shape)
def rsample(self, sample_shape=()):
return torch.cat(
[
self.marginal_t.rsample(sample_shape).unsqueeze(-1),
self.marginal_s.sample(sample_shape + self.marginal_t.kappa.shape),
],
-1
)
def log_prob(self, value):
return self.marginal_t.log_prob(value[..., 0]) + self.marginal_s.log_prob(value[..., 1:])
class VonMisesFisher(distributions.TransformedDistribution):
arg_constraints = {
'loc': constraints.real_vector,
'scale': constraints.positive,
}
has_rsample = True
def __init__(self, loc, scale, validate_args=None):
self.loc, self.scale, = loc, scale
super().__init__(
_JointTSDistribution(
MarginalTDistribution(
loc.shape[-1], scale, validate_args=validate_args
),
HypersphericalUniform(
loc.shape[-1] - 1,
batch_shape=loc.shape[:-1],
validate_args=validate_args,
),
),
[TTransform(), HouseholderRotationTransform(loc)],
)
def log_prob_true(self, x):
if self._validate_args:
self._validate_sample(x)
return self._log_unnormalized_prob(x) + self._log_normalization()
def _log_unnormalized_prob(self, x): # k<μ,x>
kappa = self.base_dist.marginal_t.kappa
mu = self.loc
return kappa * (mu * x).sum(-1, keepdim=True)
def _log_normalization(self): # logCp(kappa)
m = self.loc.shape[-1]
kappa = self.base_dist.marginal_t.kappa
return (
(m / 2 - 1) * torch.log(kappa)
- (m / 2) * math.log(2 * math.pi)
- torch.log(ive(m / 2 - 1, kappa)) - kappa
)
if __name__ == '__main__':
vmf = VonMisesFisher(F.normalize(torch.randn(8), dim=-1), torch.tensor([3.0]))
x = vmf.sample(torch.Size([2]))
print(x)
print(vmf.log_prob(x))
print(vmf.log_prob_true(x))
from scipy import stats
svmf = stats.vonmises_fisher(vmf.loc.numpy(), vmf.scale.numpy().item())
print(svmf.logpdf(x.detach().numpy()))
但我始终想不明白为什么.
According to TTransform
, we have
[
t
,
1
−
t
2
v
1
,
⋯
,
1
−
t
2
v
m
−
1
]
=
t
t
r
a
n
s
f
o
r
m
(
[
t
,
v
1
,
⋯
,
v
m
−
1
]
)
[t, \sqrt{1-t^2}v_1, \cdots, \sqrt{1-t^2}v_{m-1}] = ttransform([t, v_1, \cdots, v_{m-1}])
[t,1−t2v1,⋯,1−t2vm−1]=ttransform([t,v1,⋯,vm−1]). Then the Jacobian matrix:
J
=
[
1
0
⋯
0
−
t
v
1
1
−
t
2
1
−
t
2
⋮
0
⋮
⋮
⋱
⋮
−
t
v
m
−
1
1
−
t
2
0
⋯
1
−
t
2
]
J = \begin{bmatrix} 1 & 0 & \cdots & 0 \\ \frac{-tv_1}{\sqrt{1-t^2}} & \sqrt{1-t^2} & \vdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ \frac{-tv_{m-1}}{\sqrt{1-t^2}}& 0 & \cdots & \sqrt{1-t^2} \end{bmatrix}
J=
11−t2−tv1⋮1−t2−tvm−101−t2⋮0⋯⋮⋱⋯00⋮1−t2
⇒
d
e
t
(
J
)
=
(
1
−
t
2
)
m
−
1
2
\Rightarrow det(J) = (1-t^2)^\frac{m-1}{2}
⇒det(J)=(1−t2)2m−1, I get the result
l
o
g
∣
J
∣
=
m
−
1
2
l
o
g
(
1
−
t
2
)
log|J| = \frac{m-1}{2}log(1-t^2)
log∣J∣=2m−1log(1−t2).