深入学习 torch.distributions

0. 引言

前几天分几篇博文精细地讲述了《von Mises-Fisher 分布》, 以及相应的 PyTorch 实现《von Mises-Fisher Distribution (代码解析)》, 其中以 Uniform 分布为例简要介绍了 torch.distributions 包的用法. 本以为已经可以了, 但这两天看到论文 The Power Spherical distribution代码, 又被其实现分布的方式所吸引.

Power Spherical 分布与 von Mises Fisher 分布类似, 只不过将后者概率密度函数中的指数函数换成了多项式函数: f p ( x ; μ , κ ) ∝ e x p ( κ μ ⊺ x ) ⇓ f p ( x ; μ , κ ) ∝ ( 1 + μ ⊺ x ) κ \begin{aligned} f_p(\bm{x}; \bm{\mu}, \kappa) &\propto exp(\kappa \bm{\mu}^\intercal \bm{x}) \\ &\Downarrow\\ f_p(\bm{x}; \bm{\mu}, \kappa) &\propto (1+\bm{\mu}^\intercal \bm{x})^\kappa \\ \end{aligned} fp(x;μ,κ)fp(x;μ,κ)exp(κμx)(1+μx)κ 采样框架基本一致, 且这么做可以使边缘 t t t 的线性变换 t + 1 2 ∼ B e t a ( p − 1 2 + κ , p − 1 2 ) \frac{t+1}{2} \sim Beta(\frac{p-1}{2}+\kappa, \frac{p-1}{2}) 2t+1Beta(2p1+κ,2p1), 从而避免了接受-拒绝采样过程.

当然, 按照之前的 VonMisesFisher 的写法, 这个 t 的采样大概是这样:

z = beta.sample(sample_shape)
t = 2 * z - 1

但现在我遇到了这种写法:

class MarginalTDistribution(tds.TransformedDistribution):
	arg_constraints = {
		'dim': constraints.positive_integer,
		'scale': constraints.positive,
	}
	has_rsample = True

	def __init__(self, dim, scale, validate_args=None):
		self.dim = dim
		self.scale = scale
		super().__init__(
			tds.Beta(  # 用 Beta 分布转换, z 服从 Beta(α+κ,β)
				(dim - 1) / 2 + scale, (dim - 1) / 2, validate_args=validate_args
			),
			transforms=tds.AffineTransform(loc=-1, scale=2),  # t=2z-1 是想要的边缘分布随机数
		)

然后就可以进行对 t t t 的采样了.

架构大概是这样的: 一个基本分布类 distributions.Beta 和一个转换 transforms.AffineTransform, 输入到 TransformedDistribution 的子类 MarginalTDistribution 中, 通过对一个 B e t a Beta Beta 的线性转换, 实现边缘分布 t t t.

图1: 上述代码的解构图. 浅蓝色代表抽象基类, 绿色代表实类; 虚线代表继承, 实线代表参数输入

我们可以看到其基本架构, 本文将详细解析其内部的具体细节, 包括:

1. Distribution

在之前的 <von Mises-Fisher Distribution (代码解析)> 中, 已经通过 Uniform 简单介绍了 Distribution 的用法. 它是实现各种分布的抽象基类. 本文将以解析源码的方式详细介绍.

1.1 参数验证 validate_args

打开源码, 首先映入眼帘的是关于参数验证的代码:

# true if Python was not started with an -O option. See also the assert statement.
_validate_args = __debug__

@staticmethod
def set_default_validate_args(value: bool) -> None:
	"""
	设置 validation 是否开启.
	validation 通常是耗时的, 所以最好在模型 work 后关闭它.
	"""
	if value not in [True, False]:
		raise ValueError
	Distribution._validate_args = value

Distribution 有一个类属性叫 _validate_args, 默认值是 __debug__(见附录1), 可以通过类静态方法 set_default_validate_args(value: bool) 来修改此值.

构造方法 __init__(...) 中的验证逻辑:

def __init__(self, ..., validate_args: Optional[bool]=None):
	...
	if validate_args is not None:
		self._validate_args = validate_args

也就是说, 你可以在创建 Distribution 实例的时候设置是否进行参数验证. 如果不设置, 则按照类的属性 Distribution._validate_args.

if self._validate_args:  # validate_args=False 就不用设置 arg_constraints 了
	try:  # 尝试获取字典 arg_constraints
		arg_constraints = self.arg_constraints
	except NotImplementedError:  # 如果没设置, 则设置为 {}, 抛出警告
		arg_constraints = {}
		warnings.warn(...)

如果需要验证参数, 那么首先要获取一个叫 arg_constraints 的参数验证字典, 它列出了需要验证哪些参数. 这个抽象类里面并没有给出, 需要用户继承该类时写在子类中. 以 Uniform 为例:

class Uniform(Distribution):
	...
	arg_constraints = {
		"low": constraints.dependent(is_discrete=False, event_dim=0),
		"high": constraints.dependent(is_discrete=False, event_dim=0),
	}
	...

至于 constraints.dependent 是啥, 后面会详细介绍. 值得注意的是, 如果你在创建实例时指定 validate_args=False, 那么所有关于参数验证的事就都不用管了.

for param, constraint in arg_constraints.items():
	if constraints.is_dependent(constraint):
		continue  # skip constraints that cannot be checked
	if param not in self.__dict__ and isinstance(
			getattr(type(self), param), lazy_property
	):
		continue  # skip checking lazily-constructed args
	value = getattr(self, param)  # 从当前对象获取参数 value
	valid = constraint.check(value)  # 检查参数值
	if not valid.all():  # 检查不通过
		raise ValueError(...)

这一段就是验证过程了, 包括:

  • skip constraints that cannot be checked, 由 constraints.is_dependent(constraint) 判断是否可验证;
  • skip checking lazily-constructed args, 即参数名不在 self.__dict__ 中, 并属于 lazy_property 的跳过;
  • 获得参数, 进行验证;

具体的验证细节将在后面介绍.

1.2 batch_shape & event_shape

除了 validate_args 参数, __init__(...) 方法中的另外两个参数就是:

def __init__(
		self,
		batch_shape: torch.Size = torch.Size(),
		event_shape: torch.Size = torch.Size(),
):
	self._batch_shape = batch_shape
	self._event_shape = event_shape
	...

这两个参数是啥? 在这个抽象类中, 我们看不到太多信息, 甚至 Uniform 中也只有 batch_shape = self.low.size() 的信息, 大概意思同时进行着一批的均匀分布, 如 low = torch.tensor([0.0, 1.0]) 时, batch_shape = torch.Size([2]), 表示一个二元的均匀分布. 看 MultivariateNormal, 里面信息量较大:

batch_shape = torch.broadcast_shapes(
	covariance_matrix.shape[:-2],  # [:-2]是去掉了协方差矩阵的维度, 剩下的可能是 batch 的维度
	loc.shape[:-1]  # [:-1]是去掉了 envent 的维度, 剩下的可能是 batch 的维度
)  # broadcast_shapes 意思是进行了广播, 如果 matrix 的 batch_shape 是 [2,1], loc 的 batch_shape 是 [1,2], 那么整个的 batch_shape 是广播后的 [2,2]
self.covariance_matrix = covariance_matrix.expand(batch_shape + (-1, -1))  # 之后 covariance_matrix 都被 expand 了
...
event_shape = self.loc.shape[-1:]  # 看来就是样本的 shape

从这一段来看, batch_shape 是指创建的实例在进行多少个平行的基本分布, 而 event_shape 是指基本分布的事件(支撑点)维度. 如:

locs = torch.randn(2, 3)
matrixs = torch.randn(2, 3, 3)
covariance_matrixs = torch.bmm(matrixs, matrixs.transpose(1, 2))

normal = distributions.MultivariateNormal(loc=locs, covariance_matrix=covariance_matrixs)
print(normal.batch_shape)  # 2
print(normal.event_shape)  # 3
print(normal.sample())
##### output #####
torch.Size([2])
torch.Size([3])
tensor([[ 1.8972, -0.3961, -0.1530],
		[-0.5018, -2.5110,  0.1293]])

batch 的意思还是那个 batch, 不过这里是指分布的 batch, 而不是数据的 batch. 采样时, 得到一批 samples, 对应每个分布.

还有一个 method 和这两个参数有关: expand, 因为它是一个抽象 method, 基类中并没有实现, 那就直接看 MultivariateNormal 中的:

def expand(self, batch_shape: torch.Size, _instance=None):
	"""
	Args:
		batch_shape (torch.Size): the desired expanded size.
		_instance: new instance provided by subclasses that need to override `.expand`.
	Returns:
		New distribution instance with batch dimensions expanded to `batch_size`.
	"""
	new = self._get_checked_instance(MultivariateNormal, _instance)
	batch_shape = torch.Size(batch_shape)
	loc_shape = batch_shape + self.event_shape
	cov_shape = batch_shape + self.event_shape + self.event_shape
	new.loc = self.loc.expand(loc_shape)
	new._unbroadcasted_scale_tril = self._unbroadcasted_scale_tril
	if "covariance_matrix" in self.__dict__:
		new.covariance_matrix = self.covariance_matrix.expand(cov_shape)
	if "scale_tril" in self.__dict__:
		new.scale_tril = self.scale_tril.expand(cov_shape)
	if "precision_matrix" in self.__dict__:
		new.precision_matrix = self.precision_matrix.expand(cov_shape)
	super(MultivariateNormal, new).__init__(
		batch_shape, self.event_shape, validate_args=False
	)
	new._validate_args = self._validate_args
	return new

这个 method 会创建一个新的 instance 或调用的时候用户提供, 并设置 batch_shape 为参数提供的形状, 然后把参数 expand 到新的 batch_shape. 用法:

mean = torch.randn(3)
matrix = torch.randn(3, 3)
covariance_matrix = torch.mm(matrix, matrix.t())

mvn = MultivariateNormal(mean, covariance_matrix)
bmvn = mvn.expand(torch.Size([2]))

print(bmvn.batch_shape)
print(bmvn.event_shape)
print(bmvn.sample())

##### output #####
torch.Size([2])
torch.Size([3])
tensor([[-4.0891, -4.2424,  6.2574],
		[ 0.7656, -0.2199, -0.9836]])

[小结]: 关于 batch_shapeevent_shape 的意义

  • Distribution 允许创建一批同类型分布, batch_shape 是指"批的形状", 如一批 3MultivariateNormal 分布, 你需要提供 shape=[batch_shape|3] 的一批 3 维向量, 以及 shape=[batch_shape|3,3] 的一批 3x3 协方差矩阵.
  • event_shape 就是支撑点的 shape, 如 3MultivariateNormal 分布的支撑点的 shape=(3,).
1.3 一些属性

包括: m e a n mean mean, m o d e mode mode, s t d std std, v a r i a n c e variance variance, e n t r o p y entropy entropy 等基本属性, 都需要用户在子类中自己实现. 还有一些相关的函数:

  • cumulative density/mass function cdf(value);
  • inverse cumulative density/mass function icdf(value);
    这个函数非常有用, Inverse Transform Sampling 中用其进行采样. 从 U ( 0 , 1 ) U(0,1) U(0,1) 中采样一个 u u u, 然后令 x = F − 1 ( u ) x = F^{-1}(u) x=F1(u) 就是所求随机变量 X X X 的一个采样.
  • log of the probability density/mass function log_prob(value), 对数概率.

注意, 目前看到的只有 log_prob, 并没有 prob, 一些示例要么只算 log_prob, 要么计算后通过 exp(log_prob) 得到 prob.

2. constraints.Constraint

前面在1.1参数验证中已经遇到 constraints.dependent(is_discrete=False, event_dim=0)constraint.check(value), 但没有讲具体细节. 本节将详细剖析.

2.1 抽象基类 Constraint

先看源码:

class Constraint:
	"""
	一个 constraint 对象, 表示变量在某区域内有效, 即变量可优化的范围.
	"""

	is_discrete = False  # Default to continuous.
	event_dim = 0  # Default to univariate.

	def check(self, value):
		"""
		结果的形状为"sample_shape + batch_shape", 指示 each event 值是否满足此限制.
		"""
		raise NotImplementedError

这是抽象基类 Constraint, 比较简单, 只有两个类属性和一个 method check(value). is_discrete 表示待验证值是否为离散; 联想前面的 event_shape, 大概可以知道 event_dim 是指 len(event_shape).(不过目前看只是为了验证参数, 还能验证采样的 event?)

2.2 _Dependent() 不被验证

这个基类信息太少, 对我们理解前面的内容毫无用处, 还是直接观察一些子类吧. 从 dependent = _Dependent() 开始, 它是 constraints.py 中定义好的 placeholder(这个倒是可以学一学):

class _Dependent(Constraint):  # 看"_", 应该是不希望用户直接创建实例
	"""
	Placeholder for variables whose support depends on other variables.
	These variables obey no simple coordinate-wise constraints.
	"""

	def __init__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented):
		self._is_discrete = is_discrete
		self._event_dim = event_dim
		super().__init__()

	def __call__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented):
		"""
		Support for syntax to customize static attributes::
			constraints.dependent(is_discrete=True, event_dim=1)
		"""
		if is_discrete is NotImplemented:  # 未提供就是默认
			is_discrete = self._is_discrete
		if event_dim is NotImplemented:
			event_dim = self._event_dim
		return _Dependent(is_discrete=is_discrete, event_dim=event_dim)

	def check(self, x):
		raise ValueError("Cannot determine validity of dependent constraint")

闹了半天, 我们并不能看到 constraints.dependent(is_discrete=False, event_dim=0) 有什么卵用, 只知道 “Cannot determine validity of dependent constraint”, 这也呼应了前面的:

if constraints.is_dependent(constraint):
	continue  # skip constraints that cannot be checked

也就是说, dependent 类型的限制是不会执行参数验证的. 那这个 _Dependent 到底有何用处? 先不管了.

2.3 _IndependentConstraint 重新解释 event_dim

我们看点复杂的, MultivariateNormal.arg_constraints:

arg_constraints = {
	"loc": constraints.real_vector,
	"covariance_matrix": constraints.positive_definite,
	"precision_matrix": constraints.positive_definite,
	"scale_tril": constraints.lower_cholesky,
}

这些都是 constraints.py 中定义好的实例, 对于大多情况, 这些预定义好的实例已经够用, 但如果需要, 你也可以自定义. 先看 real_vector:

independent = _IndependentConstraint
real_vector = independent(real, 1)
class _IndependentConstraint(Constraint):
	"""
	封装一个 constraint,  通过 aggregating over ``reinterpreted_batch_ndims``-many dims in :meth:`check`,
	an event is valid 当且仅当它依赖的所有 entries 是 valid 的.
	"""

	def __init__(self, base_constraint, reinterpreted_batch_ndims):
		self.base_constraint = base_constraint
		self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
		super().__init__()

	@property
	def event_dim(self):
		# real.event_dim 是 0, + real_vector(reinterpreted_batch_ndims=1) = 1
		return self.base_constraint.event_dim + self.reinterpreted_batch_ndims

	def check(self, value):
		result = self.base_constraint.check(value)  # 首先要符合 base.check
		if result.dim() < self.reinterpreted_batch_ndims:
			# 给 batch 留够 dim
			expected = self.base_constraint.event_dim + self.reinterpreted_batch_ndims
			raise ValueError(
				f"Expected value.dim() >= {expected} but got {value.dim()}"
			)
		result = result.reshape(  # 减掉 event
			result.shape[: result.dim() - self.reinterpreted_batch_ndims] + (-1,)
		)
		result = result.all(-1)  # 减少一个 dim
		return result

意思很明了了, real_vector 是依赖于 real(base_constraint) 的, reinterpreted_batch_ndims=1 是说把原来 valuebatch_dim 重新解释, 分出 n 个给 event_dim: 加上 reinterpreted_batch_ndims, 比如

value = [[1, 2, 3],
		 [4, 5, 6]]

本来 realevent_dim=0, 验证结果为(sample_shape + batch_shape = (2,2)):

value = [[True, True, True],
		 [True, True, True]]

现在重新解释为 event_dim=1, 验证结果为:

result = result.reshape(  # 减掉 event
   	result.shape[: result.dim() - self.reinterpreted_batch_ndims] + (-1,)  # (-1,) 表示新 event 内的所有 entries 展平
)
result = result.all(-1)  # 新 event 内的所有 entries 为 True, 则新 event 为 True
================>
value = [True, True]

3. Transform & _InverseTransform

上一节介绍了 constraints.Constraint, 明白了在构建 Distribution 实例时进行的参数验证, 以保证用户提供的参数符合要求. 但还留下了一个疑问: Constraint 中的 event_dim 是指 len(event_shape), 难道还能验证采样的 event? 再者, check(value) 返回值的形状是 sample_shape + batch_shape, 进一步说明它是会被用于采样结果检查的. 让我们看一看能否在 Transform 中找到答案.

Transform & _InverseTransform 是一对互逆的操作, 实现从一个分布到另一个分布的转换. 这很有用, 因为 distributions 包已经实现了很多常见分布和转换, 自由组合威力巨大. 本节将详细介绍它是如何实现对分布的转换的.
[注] 从 _InverseTransform_ 来看, 是不需要用户了解它的.

3.1 抽象类 Transform 的基本信息
class Transform:
	"""
	变换的抽象基类, 子类应该实现 one or both of `_call` or `_inverse`.
	如果 `bijective=True`, 则必须实现 `log_abs_det_jacobian`.

	Args:
		cache_size (int): If one, the latest single value is cached.
		Only 0 and 1 are supported.
	"""
	bijective = False  # Transform 是否双射, 默认 False
	domain: constraints.Constraint  # 有效输入范围
	codomain: constraints.Constraint  # 有效输出范围

	def __init__(self, cache_size=0):
		self._cache_size = cache_size
		self._inv = None
		if cache_size == 0:
			pass  # default behavior
		elif cache_size == 1:
			self._cached_x_y = None, None
		else:
			raise ValueError("cache_size must be 0 or 1")
		super().__init__()

果然, Transform 中有 Constraint 的, 分别是 domaincodomain, 用于其检查输入输出是否符合要求. 此外, 还有 bijectivecache_size 这两个信息, 等一下看后面怎么说.

3.2 AffineTransform

抽象类的基本信息不多, 还是要看一个简单的例子: AffineTransform, 线性变换.

class AffineTransform(Transform):
	bijective = True

	def __init__(self, loc, scale, event_dim=0, cache_size=0):
		super().__init__(cache_size=cache_size)
		self.loc = loc
		self.scale = scale
		self._event_dim = event_dim

线性变换是可逆的, 可以看到它的 bijective = True. 参数是 y   =   l o c   +   s c a l e   ×   x y = loc + scale × x y=loc+scale×x 中的 locscale; event_dim 则是用于构建 domaincodomain:

@constraints.dependent_property(is_discrete=False)
def domain(self):
	if self.event_dim == 0:
		return constraints.real
	return constraints.independent(constraints.real, self.event_dim)

@constraints.dependent_property(is_discrete=False)
def codomain(self):
	if self.event_dim == 0:
		return constraints.real
	return constraints.independent(constraints.real, self.event_dim)

即, domaincodomain 被限制为 event_dim 维向量, 默认是 0, 输入输出皆为标量.

变换过程
def _call(self, x):
	"""
	Method to compute forward transformation.
	"""
	return self.loc + self.scale * x

def _inverse(self, y):
	"""
	Method to compute inverse transformation.
	"""
	return (y - self.loc) / self.scale

由于是双射, 还要实现:

def log_abs_det_jacobian(self, x, y):
	shape = x.shape
	scale = self.scale
	if isinstance(scale, numbers.Real):
		result = torch.full_like(x, math.log(abs(scale)))
	else:
		result = torch.abs(scale).log()
	if self.event_dim:
		result_size = result.size()[: -self.event_dim] + (-1,)
		result = result.view(result_size).sum(-1)
		shape = shape[: -self.event_dim]
	return result.expand(shape)

计算结果的形状调整为 x 中除 event_dim 以外的形状, 即 sample_shape + batch_shape. 至于为什么要这么做, 还需要看 TransformedDistribution 中具体的转换流程.

这里有个问题, 假设 event_dim=1, 输入的 x.shape=(2,3), 而 scale=2.0scale=torch.tensor(2.0)计算结果是不一致的:

====================== scale=2.0 ==========================
result = torch.full_like(x, math.log(abs(2.0)))
[[log(2), log(2), log(2)],
 [log(2), log(2), log(2)]]
result_size = (2,3)[: -1] + (-1,) = (2,3)
result = [3log(2), 3log(2)].expand([2]) = [3log(2), 3log(2)]
================== scale=tensor(2.0) =======================
result = torch.abs(scale).log() = log(2)
result_size = ()[: -1] + (-1,) = (-1,)
result = log(2).expand([2]) = [log(2), log(2)]

类似的, 只要 scaletensor, 并出现了计算广播, 就会出现这种情况. 不知道会不会造成计算错误, 看了后面的 TransformedDistribution 就能知道. 现在只能暂时不管了.

经测验, 果然是有问题的:

mn = distributions.MultivariateNormal(torch.tensor([0., 0., 0.]), torch.eye(3))
amn1 = distributions.TransformedDistribution(
	mn,
	[distributions.AffineTransform(torch.tensor([1., 2., 3.]), torch.tensor(2.0), event_dim=1)]
)
amn2 = distributions.TransformedDistribution(
	mn,
	[distributions.AffineTransform(torch.tensor([1., 2., 3.]), 2.0, event_dim=1)]
)
x = mn.sample(torch.Size([2]))
xs = x * 2.0 + torch.tensor([1., 2., 3.])
print(amn1.log_prob(xs))
print(amn2.log_prob(xs))

########## output ###########
tensor([0.6931, 0.6931])  # 输出的 log_abs_det_jacobian for scale=torch.tensor(2.0)
tensor([-4.6540, -4.7094])
tensor([2.0794, 2.0794])  # 输出的 log_abs_det_jacobian scale=2.0
tensor([-6.0403, -6.0957])

计算的概率密度也不一致, 理论上应该是后者对. 所以, 源代码编写者的原意可能是: 只要 scaletensor, 那么我就按你是broadcast.

3.3 TransformedDistribution
3.3.1 基本信息
class TransformedDistribution(Distribution):
	"""
	Extension of the Distribution class, which applies a sequence of Transforms
	to a base distribution.
	"""
	arg_constraints: Dict[str, constraints.Constraint] = {}

	def __init__(self, base_distribution, transforms, validate_args=None):
		>>> 单 transfrom 变成 [transfrom], 再检查是否符合 transforms: List[Transform] <<<

它是对 Distribution 的扩展, 对一个 base distribution 实施一连串的 Transforms:

X ~ BaseDistribution
Y = f(X) ~ TransformedDistribution(BaseDistribution, f)
log p(Y) = log p(X) + log |det (dX/dY)|

一个简单的例子:

# #################################
# Building a Logistic Distribution
# X ~ Uniform(0, 1)
# f = a + b * logit(X)
# Y ~ f(X) ~ Logistic(a, b)
# #################################
base_distribution = Uniform(0, 1)
transforms = [SigmoidTransform().inv, AffineTransform(loc=a, scale=b)]
logistic = TransformedDistribution(base_distribution, transforms)

其中 l o g i t ( x ) = l o g x 1 − x logit(x) = log\frac{x}{1-x} logit(x)=log1xx s i g m o i d sigmoid sigmoid 函数的逆.

下面是 TransformedDistribution__init__(...) 内容(省略了开头将单 Transform 转换为列表以及检查类型的代码):

# >>> Reshape base_distribution according to transforms. >>>
# >>> 获取 base_distribution 的 batch_shape 和 event_shape 以及 event_dim >>>
base_shape = base_distribution.batch_shape + base_distribution.event_shape
base_event_dim = len(base_distribution.event_shape)  # 的基本 shape
# <<< 获取 base_distribution 的 batch_shape 和 event_shape 以及 event_dim <<<
# 将 transforms 组合成一个 transform
transform = ComposeTransform(self.transforms)
# 先正向传播 shape, 再反向传播 shape, 一来一回 shape 不一致, 说明途中发生了广播
# 具体例子可为: 线性转换中的 [1,2,3] * [[2],[3]], 输入向量输出矩阵(再反向也是矩阵)
forward_shape = transform.forward_shape(base_shape)
expanded_base_shape = transform.inverse_shape(forward_shape)
if base_shape != expanded_base_shape:  # 不一致说明发生了广播 (AffineTransform为例)
	base_batch_shape = expanded_base_shape[
		: len(expanded_base_shape) - base_event_dim
	]  # 干脆先把 base_distribution 给 expand 了
	# 如 base_shape = batch_shape + event_shape = (,) + (,3) = (3,)
	# expanded_base_shape = (2,3), 则 base_batch_shape = (2,)
	base_distribution = base_distribution.expand(base_batch_shape)  # 结果 base_shape = (2,3)
# transform.domain.event_dim 是指所有 transforms 中最大的 domain.event_dim (这个 domain.event_dim 可能就只是为了检查 dim 是否够用)
reinterpreted_batch_ndims = transform.domain.event_dim - base_event_dim
if reinterpreted_batch_ndims > 0:
	base_distribution = Independent(  # 但却实实在在地调整了 base_distribution 的 event_dim
		base_distribution, reinterpreted_batch_ndims
	)  # 参考前面讲的 _IndependentConstraint
self.base_dist = base_distribution

这一部分的主旋律是 Reshape base_distribution according to transforms. 也就是说, self.base_dist 被赋予的是调整过的 base_distribution. 主要包括:

  • 调整 batch_shape, by base_distribution.expand(base_batch_shape), 前面讲过 expand;
  • 调整 event_shape, by Independent, 这个类似前面讲的 _IndependentConstraint, 只不过这里是对 Distribution 操作;

然而基类 Distribution 默认的 expand 未实现, 所以如果预期 Transforms 链中间会有接口不兼容的时候, 要注意实现 expand, 否则出错. 不过出错也好, 让你知道不兼容, 从而减少未知性.

具体过程看注释. 所以, 使用这种方式建立新的 Distribution 时, 要同时注意 base_distributiontransformsevent_dim, 这对 log_prob 的计算有影响, 且 base_distributionevent_dim 可能被更改.

安排好 self.base_dist 后, 开始计算本 TransformedDistributionbatch_shapeevent_shape.

# Compute shapes.
transform_change_in_event_dim = (  # transform 导致的 event_dim 变化
	transform.codomain.event_dim - transform.domain.event_dim
)
event_dim = max(
	transform.codomain.event_dim,  # the transform is coupled
	base_event_dim + transform_change_in_event_dim,  # the base dist is coupled
)
assert len(forward_shape) >= event_dim
cut = len(forward_shape) - event_dim  # forward_shape 劈开
batch_shape = forward_shape[:cut]
event_shape = forward_shape[cut:]
super().__init__(batch_shape, event_shape, validate_args=validate_args)
3.3.2 采样
def sample(self, sample_shape=torch.Size()):
	with torch.no_grad():
		x = self.base_dist.sample(sample_shape)
		for transform in self.transforms:
			x = transform(x)
		return x

def rsample(self, sample_shape=torch.Size()):
	x = self.base_dist.rsample(sample_shape)
	for transform in self.transforms:
		x = transform(x)
	return x
3.3.3 log_prob 需要 log_abs_det
def log_prob(self, value):
	if self._validate_args:  # 验证样本的就在此处了
		self._validate_sample(value)
	event_dim = len(self.event_shape)
	log_prob = 0.0
	y = value
	for transform in reversed(self.transforms):  # 倒着来
		x = transform.inv(y)  # 逆变换得到 x, 想计算 `log_prob`, 逆变换就得实现.
		event_dim += transform.domain.event_dim - transform.codomain.event_dim
		log_prob = log_prob - _sum_rightmost(
			transform.log_abs_det_jacobian(x, y),
			event_dim - transform.domain.event_dim,
		)
		y = x

	log_prob = log_prob + _sum_rightmost(
		self.base_dist.log_prob(y), event_dim - len(self.base_dist.event_shape)
	)
	return log_prob

根据 f y ( y ) = f X ( x ) / ∣ d y d x ∣ f_y(y) = f_X(x)/|\frac{dy}{dx}| fy(y)=fX(x)/∣dxdy, 比较容易理解 l o g f y ( y ) = l o g f X ( x ) − l o g ∣ d y d x ∣ logf_y(y) = logf_X(x) - log|\frac{dy}{dx}| logfy(y)=logfX(x)logdxdy, 那么代码中大概的框架是连续减 l o g ∣ d y d x ∣ log|\frac{dy}{dx}| logdxdy, 到这没什么问题. 问题就在于为什么要执行:

_sum_rightmost(
	transform.log_abs_det_jacobian(x, y),
	event_dim - transform.domain.event_dim,
)

event_dim - transform.domain.event_dim 个最右侧的维度加起来? 空想不好理解, 必须举个例子, 还是以 AffineTransform 为例:

event_dim > transform.domain.event_dim

假设有:

value = [[1, 2, 3],
		 [4, 5, 6]], event_dim=1
affine = AffineTransform(1, 2, event_dim=0)  # transform.domain.event_dim = 0

则计算的 log_abs_det 为:

[[log2, log2, log2],
 [log2, log2, log2]]

但按照 event_dim=1, 即基本的 event 单元为 [1, 2, 3][4, 5, 6], 对应的 l o g ∣ d y d x ∣ log|\frac{dy}{dx}| logdxdy[3log2, 3log2], 这就用到了上面说的 _sum_rightmost, 把 event_dim - transform.domain.event_dim 个最右侧的维度加起来.

而在在一连串的 transforms 中, event_dim += transform.domain.event_dim - transform.codomain.event_dim 代表着当前 xevent_dim.

event_dim = transform.domain.event_dim 时表示刚刚好, 输入符合 transform.codomain.event_dim 的要求. 有没有可能 event_dim < transform.domain.event_dim? 怕是不能!

4. 实战解析及解惑

在读了 torch.distributions 包的源码之后, 让我们回到论文 The Power Spherical distribution代码, 再一看则豁然开朗. 包括边缘变量 t t t均匀子球的采样 v \bm{v} v 的组合操作, 以及组合后的 Householder 变换. 详情请参阅《von Mises-Fisher 分布》.

4.1 边缘变量 t t t 和均匀子球的采样 v \bm{v} v 的组合操作
class _TTransform(tds.Transform):
	"""
	大概就是 cat(t,v) 的吧, 注意, t 在开头
	"""
	# 设置为向量还是有必要的, 因为传递过程中的 log_abs_det_jacobian 计算要用到 event_dim
	# real.event_dim=0, real_vector.event_dim=1
	# 关系到传播中 log_prob 的计算
	domain = constraints.real  # 输入空间是实数
	codomain = constraints.real  # 输出空间也是实数

	def _call(self, x):
		t = x[..., 0].unsqueeze(-1)
		v = x[..., 1:]
		return torch.cat([t, v * torch.sqrt(torch.clamp(1 - t ** 2, _EPS))], -1)

	def _inverse(self, y):
		t = y[..., 0].unsqueeze(-1)
		v = y[..., 1:]
		return torch.cat([t, v / torch.sqrt(torch.clamp(1 - t ** 2, _EPS))], -1)

	def log_abs_det_jacobian(self, x, y):
		"""
		计算变换后的分布的概率密度时有用 fY(y) = fX(x(y))|dx/dy|
		:param x: input
		:param y: output
		:return: the log det jacobian log |dy/dx| given input and output
		"""
		t = x[..., 0]
		# return ((x.shape[-1] - 3) / 2) * torch.log(torch.clamp(1 - t ** 2, _EPS))  # 怎么感觉是 (d-1)/2?
		return ((x.shape[-1] - 1) / 2) * torch.log(torch.clamp(1 - t ** 2, _EPS))

_call_inverse 都是对的, 但感觉 log_abs_det_jacobian 有问题. 首先我们从数学上先推导一下这个变换的 d y d x \frac{dy}{dx} dxdy. 设 [ t , 1 − t 2 v 1 , ⋯   , 1 − t 2 v m − 1 ] = t t r a n s f o r m ( [ t , v 1 , ⋯   , v m − 1 ] ) [t, \sqrt{1-t^2}v_1, \cdots, \sqrt{1-t^2}v_{m-1}] = ttransform([t, v_1, \cdots, v_{m-1}]) [t,1t2 v1,,1t2 vm1]=ttransform([t,v1,,vm1]), 其中 m m m 是向量的维度. 那么雅可比矩阵为: J = [ 1 0 ⋯ 0 − t v 1 1 − t 2 1 − t 2 ⋮ 0 ⋮ ⋮ ⋱ ⋮ − t v m − 1 1 − t 2 0 ⋯ 1 − t 2 ] J = \begin{bmatrix} 1 & 0 & \cdots & 0 \\ \frac{-tv_1}{\sqrt{1-t^2}} & \sqrt{1-t^2} & \vdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ \frac{-tv_{m-1}}{\sqrt{1-t^2}}& 0 & \cdots & \sqrt{1-t^2} \end{bmatrix} J= 11t2 tv11t2 tvm101t2 0001t2 d e t ( J ) = ( 1 − t 2 ) m − 1 2 det(J) = (1-t^2)^\frac{m-1}{2} det(J)=(1t2)2m1, 那么 l o g ∣ d y d x ∣ = m − 1 2 l o g ( 1 − t 2 ) log|\frac{dy}{dx}| = \frac{m-1}{2}log(1-t^2) logdxdy=2m1log(1t2). 我们看一看代码计算结果是什么, 值是刚才计算的值, 形状为 x.shape[:-1]. 那么问题来了, 如果不管 domain.event_dim, 这个计算结果还是对的, 但这里的 domain.event_dim=real.event_dim=0, 而根据 TransformedDistribution 中的 # Compute shapes, 可计算得到 self.event_dim=1, 此时 event_dim - transform.domain.event_dim=1, 你返回这个值后, 计算 log_prob 会再执行:

_sum_rightmost(
	transform.log_abs_det_jacobian(x, y),
	event_dim - transform.domain.event_dim
)

从而导致其在 sample_shape[-1] 上相加, 减少一个维度, 这肯定是不对的. 假设我们构造分布:

mn = distributions.MultivariateNormal(torch.zeros(2, 3), torch.eye(3).expand(2, 3, 3))
tmn = distributions.TransformedDistribution(mn, [_TTransform()])

x = tmn.sample(torch.Size([]))
print(x)
print(tmn.log_prob(x))

########## output ##########
tensor([[ 0.1690, -0.2441, -1.0599],
        [-0.2891,  0.4383,  2.4503]])
tensor([-3.2637, -6.0629])

哎! 看着没问题啊, 但实际上此时 _TTransform.log_abs_det_jacobian 得到的是 shape=(2,)tensor, 然后被 _sum_rightmost 减少了一维, 得到一个 scalar, 再加 MultivariateNormal.log_prob(shape=[2]) 时是可广播的, 故而不会报错(实际上已经计算错误). 我们找到 TransformedDistribution.log_prob, 添加一些 print 打印信息:

def log_prob(self, value):
	"""
	Scores the sample by inverting the transform(s) and computing the score
	using the score of the base distribution and the log abs det jacobian.
	"""
	if self._validate_args:
		self._validate_sample(value)
	event_dim = len(self.event_shape)
	log_prob = 0.0
	y = value
	for transform in reversed(self.transfrms):
		x = transform.inv(y)
		event_dim += transform.domain.event_dim - transform.codomain.event_dim
		print(event_dim)
		print(event_dim - transform.domain.event_dim)
		print(transform.log_abs_det_jacobian(x, y))
		log_prob = log_prob - _sum_rightmost(
			transform.log_abs_det_jacobian(x, y),
			event_dim - transform.domain.event_dim,
		)
		y = x
	print(log_prob)
	log_prob = log_prob + _sum_rightmost(
		self.base_dist.log_prob(y), event_dim - len(self.base_dist.event_shape)
	)
	return log_prob

输出:

tensor([[-3.0071e+00,  6.7510e-19, -3.1936e-18],
        [-3.2723e-01,  1.4745e+00,  7.2704e-01]])
event_dim: 1
event_dim - transform.domain.event_dim: 1
transform.log_abs_det_jacobian(x, y): tensor([-82.8931,  -0.1133])
log_prob: 83.0063247680664
tensor([70.4010, 78.6826])

果然不出所料, 在 sample_shape[-1] 上相加, 减少一个维度, 这肯定是不对的. 其实也不一定叫 sample_shape, 一般采样的 shape = sample_shape + batch_shape + event_shape, 是在 sample_shape + batch_shape 的最后一维执行 sum. 假设 shape = sample_shape([4]) + batch_shape([2]) + event_shape([3]):

mn = distributions.MultivariateNormal(torch.zeros(2, 3), torch.eye(3).expand(2, 3, 3))
tmn = distributions.TransformedDistribution(mn, [_TTransform()])

x = tmn.sample(torch.Size([4]))
print(x)
print(tmn.log_prob(x))

########## output ##########
tensor([[[ 2.7956e-01, -7.6880e-01, -1.5509e+00],
         [ 8.0168e-01,  1.6566e-01,  2.4777e-02]],

        [[-8.6095e-01,  1.9766e-01, -2.6489e-01],
         [-1.0264e+00, -2.0986e-18, -7.3761e-19]],

        [[-9.9605e-02,  2.5579e-02, -9.6922e-02],
         [ 2.0281e+00,  1.4941e-18,  5.9761e-19]],

        [[-7.4705e-01, -1.7805e+00,  2.0846e-01],
         [ 8.1492e-01, -5.5181e-01,  6.8105e-03]]])
event_dim: 1
event_dim - transform.domain.event_dim: 1
transform.log_abs_det_jacobian(x, y): tensor([[-8.1379e-02, -1.0292e+00],
        [-1.3518e+00, -8.2893e+01],
        [-9.9707e-03, -8.2893e+01],
        [-8.1663e-01, -1.0909e+00]])
log_prob: tensor([ 1.1105, 84.2449, 82.9030,  1.9075])
Traceback (most recent call last):
  File "/root/PycharmProjects/OptimalTransport/EBSW/ColorTransfer/zz.py", line 35, in <module>
    print(tmn.log_prob(x))
  File "/opt/programs/pytorch/torch/distributions/transformed_distribution.py", line 177, in log_prob
    log_prob = log_prob + _sum_rightmost(
RuntimeError: The size of tensor a (4) must match the size of tensor b (2) at non-singleton dimension 1

transform.log_abs_det_jacobian(x, y)shape=(4,2) 也即 sample_shape([4]) + batch_shape([2]), 然而却把 batch_shape([2]) 整没了, 然后和

_sum_rightmost(
	self.base_dist.log_prob(y), event_dim - len(self.base_dist.event_shape)  # =0
)

shape=(4,2) 产生冲突, 报错.

即使你转化诸如 Normalevent_dim=0Distribution, 也会出现 shape 不匹配的情况(只有二维样本时形成 scalar 不会报错, 但也计算错误).

[注] 为什么明明是 l o g ∣ d y d x ∣ = m − 1 2 l o g ( 1 − t 2 ) log|\frac{dy}{dx}| = \frac{m-1}{2}log(1-t^2) logdxdy=2m1log(1t2), 代码中却写 (x.shape[-1] - 3) / 2?
答曰: 不清楚.

[小结]: domaincodomain 的限制维度很重要, 该是 real_vector 的写成 real 会出现错误.

4.2 Householder Transform
class _HouseholderRotationTransform(tds.Transform):
	"""
	完成拼接后, 要进行 HouseholderRotation
	"""
	domain = constraints.real
	codomain = constraints.real

	def __init__(self, loc: torch.Tensor):
		super().__init__()
		e1 = torch.zeros_like(loc)  # 继承
		e1[..., 0] = 1.0
		self.__u = tn_func.normalize(e1 - loc, dim=-1)

	def _call(self, x: torch.Tensor):
		return x - 2 * (x * self.__u).sum(-1, keepdim=True) * self.__u

	def _inverse(self, y: torch.Tensor):  # 逆变换是一样的
		return y - 2 * (y * self.__u).sum(-1, keepdim=True) * self.__u

	def log_abs_det_jacobian(self, x: torch.Tensor, y: torch.Tensor):
		# h = torch.eye(x.shape[-1], device=x.device) - 2 * torch.outer(self.__u, self.__u)
		# torch.log(torch.abs(torch.det(h)))
		return 0.0  # 因为 |y|=|x|, 所以 |h|=1; 正交矩阵

具体原理见《householder 变换》. 现在我们关注 log_abs_det_jacobian. y = ( I − 2 u u ⊺ ) x \bm{y} = (I - 2\bm{u}\bm{u}^\intercal) \bm{x} y=(I2uu)x, 所以雅可比矩阵为 J = I − 2 u u ⊺ J = I - 2\bm{u}\bm{u}^\intercal J=I2uu, ∣ d e t ( I − 2 u u ⊺ ) ∣ = 1 |det(I - 2\bm{u}\bm{u}^\intercal)| = 1 det(I2uu)=1, l o g ∣ d e t ( J ) ∣ = 0 log|det(J)| = 0 logdet(J)=0.

同样, 值的计算是没有问题的, 只是其返回值 0.0shape 不对, event_dim - transform.domain.event_dim=1, 你返回一个 0.0, TransformedDistribution 中的 log_prob 无法计算. 直接报错:

AttributeError: 'int' object has no attribute 'shape'

但是, 一旦把 0 改成 torch.tensor(0.0), _sum_rightmost(1) 不会对其有影响, 出来依然是标量, 能广播, 也没有加和的数值错误, 倒成了正确的.

[注]: 正交矩阵的行列式值都为 1 1 1.

4.3 代码的作者自己实现了 log_prob

既然两个转换的 log_abs_det 都有问题, 那为什么还取得了正确的测试结果? 答案是: 作者自己实现了 log_prob 的计算, 而并未使用 TransformedDistribution 中的 log_prob.

def log_prob(self, value):
	return self.log_normalizer() + self.scale * torch.log1p(
		(self.loc * value).sum(-1)
	)

假设我们注释掉这个作者实现的 log_prob, 转而使用 TransformedDistribution 中的 log_prob, 看看会有正确的结果不:

loc = torch.tensor([0.0, 1.0], requires_grad=True)
scale = torch.tensor(4.0, requires_grad=True)
dist = PowerSpherical(loc, scale)

step_size = 0.001
x = torch.arange(0, 2 * math.pi, step_size)
pt = torch.stack((torch.cos(x), torch.sin(x))).t()
y = torch.exp(dist.log_prob(pt)).detach()
print('integal:', y.sum() * step_size)

###################### output #######################
integal: tensor(inf)

竟然没有报错, 只是输出了一个 tensor(inf). 经过检查, 发现忽略了各 Distributionevent_shape, 作者竟然都设置成了默认, 即 torch.Size([]), 且各 Transformforward_shapeinverse_shape 都保持默认, 那么送进 TransformedDistribution 后, base_distribution 不会被 expand, 且 event_shape=torch.Size([]), 进而 event_dim - transform.domain.event_dim = 0 -0 = 0, _sum_rightmost 从来都不会计算.

即使 _HouseholderRotationTransform.log_abs_det_jacobian 返回值为 0, 经过:

log_prob = log_prob - _sum_rightmost(
	transform.log_abs_det_jacobian(x, y),
	event_dim - transform.domain.event_dim,
)

链条的广播计算, 也不是问题. 可以正确计算了? 那为什么是 inf? 问题就在于边缘变量 tlog_prb 的计算, 当 t = 1 t=1 t=1 ( x = μ \bm{x}=\bm{\mu} x=μ) 时, Beta 分布的 log_prob(1)=log(0), 从而出现无穷的情况, 而这本来能乘以该处的均匀子球概率密度 1 0 \frac{1}{0} 01 避免的.

4.4 心得与教训

可以看到, 这种 BaseDistribution + Transform 实现复杂分布的方式提供方便的同时, 也会存在很多问题. 这里总结几点心得与教训:

  • 基础分布和变换都已经存在时, 使用此架构可以快速搭建新的概率分布, 不必考虑太多细节, 比如 log_pdf, cdf 以及一些基础属性;
  • 当需要实现一个复杂分布时, 尽可能拆解成简单分布和变换, 注意要朝着已存在变换的方向分解, 哪怕拆出来的基础分布不在 PyTorch 的包内, 只要能使分布更简单, 就能达到简化的目的;
  • 尽量不要自己写 Transform, 因为要实现的东西有点多, _inverselog_abs_det_jacobian 都比较麻烦, 有那个功夫, 也已经把复杂分布的 pdf 算出来了;
  • 尽量直接继承 Distribution 实现自己的分布, 为不是使用转换链. 因为转换意味着要计算 log_abs_det_jacobian, 然后沿着转换链累加, 这比直接计算目标分布的 log_prob 增加了许多计算, 且承担不稳定的风险. 比如此例子中的接连三个变换, 计算复杂不说, 还拆开了 x = μ \bm{x} = \bm{\mu} x=μ 时的 f ( t = 1 ) = 0 f(t=1) = 0 f(t=1)=0 f ( 1 − t 2 v ) = 1 S p − 2 = 1 0 f(\sqrt{1-t^2}\bm{v}) = \frac{1}{S_{p-2}} = \frac{1}{0} f(1t2 v)=Sp21=01, 加之为了避免 0 0 0 作为分母而加上的 e p s eps eps 偏移, 使得计算失败.
  • 如果只是为了采样, 则不必考虑 _inverselog_abs_det_jacobian, 因为它们只出现在 log_prob.
4.5 关于 batch_shape, event_shapeConstraint.event_dim.

如果迫不得已需要使用转换链, 我还是建议尽量实事求是地把这三个参数正确地写上, 而不是全部保持为 0. 一则不符合代码逻辑, 二来不定在哪就出错了. 包的作者既然这么设置, 肯定有他的道理.

附录

1. __debug__assert (来自 Kimi)

__debug__ 是一个内置变量,用于指示 Python 解释器是否处于调试模式。当 Python 以调试模式运行时,__debug__ 被设置为 True;否则,在优化模式下运行时,它被设置为 False

__debug__ 可以用于条件性地执行调试代码,例如:

if __debug__:
	print("Debug mode is on, performing extra checks...")
	# 这里可以放一些只在调试模式下运行的代码,比如详细的日志记录
	# 或者复杂的验证逻辑
else:
	print("Debug mode is off.")

在上面的例子中,如果命令行执行:

python -O myscript.py
##### output #####
Debug mode is off.
------------------------------------------------------
python myscript.py
##### output #####
Debug mode is on, performing extra checks...

assert 语句受 __debug__ 影响:

def calculate(a, b):
	# 这个 assert 在 __debug__ 为 True 时执行
	assert a > 0 and b > 0, "Both inputs must be positive."
	
	# 正常的函数逻辑
	return a * b

# 在这里,assert 会检查输入是否为正数
result = calculate(5, 3)
print(result)

# 如果我们改变条件使 assert 失败
# result = calculate(-1, 3)  # 这会触发 AssertionError,除非运行时 __debug__ 为 False
2. t t t 的概率密度函数推导

直接将 t = μ ⊺ x t = \bm{\mu}^\intercal\bm{x} t=μx 代入 f p ( x ; μ , κ ) f_p(\bm{x}; \bm{\mu}, \kappa) fp(x;μ,κ), 得: f p ( x ; μ , κ ) = C p ( κ ) ( 1 + μ ⊺ x ) κ = C p ( κ ) ( 1 + t ) κ t ∈ [ − 1 , 1 ] = C p ( κ ) ( 1 + c o s θ ) κ θ ∈ [ 0 , π ] \begin{aligned} f_p(\bm{x}; \bm{\mu}, \kappa) &= C_p(\kappa) (1 + \bm{\mu}^\intercal\bm{x})^\kappa & \\ &= C_p(\kappa) (1 + t)^\kappa & t \in [-1, 1] \\ &= C_p(\kappa) (1 + cos\theta)^\kappa & \theta \in [0, \pi] \end{aligned} fp(x;μ,κ)=Cp(κ)(1+μx)κ=Cp(κ)(1+t)κ=Cp(κ)(1+cosθ)κt[1,1]θ[0,π] 注意这是 x \bm{x} x 一点处的概率密度. 沿着 t t t 处的切子球求积分, 以得到 t t t θ \theta θ 处的整个概率密度: ∫ 切子球 f p ( x ; μ , κ ) d s = ∫ 切子球 C p ( κ ) ( 1 + μ ⊺ x ) κ d s = C p ( κ ) ( 1 + t ) κ 2 π p − 1 2 Γ ( p − 1 2 ) ( 1 − t 2 ) p − 2 2 S p − 2 的表面积 ∝ r p − 2 = C p ( κ ) ( 1 + c o s θ ) κ 2 π p − 1 2 Γ ( p − 1 2 ) s i n p − 2 θ \begin{aligned} \int_{切子球} f_p(\bm{x}; \bm{\mu}, \kappa) ds &= \int_{切子球} C_p(\kappa) (1 + \bm{\mu}^\intercal\bm{x})^\kappa ds & \\ &= C_p(\kappa) (1 + t)^\kappa\frac{2\pi^{\frac{p-1}{2}}}{\Gamma(\frac{p-1}{2})}(1-t^2)^\frac{p-2}{2} & S^{p-2} 的表面积 \propto r^{p-2} \\ &= C_p(\kappa) (1 + cos\theta)^\kappa \frac{2\pi^{\frac{p-1}{2}}}{\Gamma(\frac{p-1}{2})} sin^{p-2}\theta & \end{aligned} 切子球fp(x;μ,κ)ds=切子球Cp(κ)(1+μx)κds=Cp(κ)(1+t)κΓ(2p1)2π2p1(1t2)2p2=Cp(κ)(1+cosθ)κΓ(2p1)2π2p1sinp2θSp2的表面积rp2 根据 n-sphere - Wikipedia, 切子球 S p − 2 S^{p-2} Sp2 的表面积 S p − 2 = 2 π p − 1 2 Γ ( p − 1 2 ) r p − 2 S_{p-2} = \frac{2\pi^{\frac{p-1}{2}}}{\Gamma(\frac{p-1}{2})} r^{p-2} Sp2=Γ(2p1)2π2p1rp2, 再沿 t t t θ \theta θ 积分: ∫ 0 π C p ( κ ) ( 1 + c o s θ ) κ 2 π p − 1 2 Γ ( p − 1 2 ) s i n p − 2 θ d θ = C p ( κ ) 2 π p − 1 2 Γ ( p − 1 2 ) ∫ 1 − 1 ( 1 + t ) κ ( 1 − t 2 ) p − 2 2 ( − 1 1 − t 2 d t ) ∵ c o s 0 = 1 ,   c o s π = − 1 = C p ( κ ) 2 π p − 1 2 Γ ( p − 1 2 ) ∫ − 1 1 ( 1 + t ) κ ( 1 − t 2 ) p − 3 2 d t = C p ( κ ) 2 π p − 1 2 Γ ( p − 1 2 ) ∫ − 1 1 ( 1 + t ) κ [ ( 1 + t ) ( 1 − t ) ] p − 3 2 d t = C p ( κ ) 2 π p − 1 2 Γ ( p − 1 2 ) ∫ − 1 1 ( 1 + t ) p − 1 2 + κ − 1 ( 1 − t ) p − 1 2 − 1 d t = C p ( κ ) 2 π p − 1 2 Γ ( p − 1 2 ) ∫ 0 1 ( 2 z ) p − 1 2 + κ − 1 ( 2 ( 1 − z ) ) p − 1 2 − 1 2 d z t = 2 z − 1 = C p ( κ ) 2 p + κ − 1 π p − 1 2 Γ ( p − 1 2 ) ∫ 0 1 z p − 1 2 + κ − 1 ( 1 − z ) p − 1 2 − 1 d z \begin{aligned} & \int_0^\pi C_p(\kappa) (1 + cos\theta)^\kappa \frac{2\pi^{\frac{p-1}{2}}}{\Gamma(\frac{p-1}{2})} sin^{p-2}\theta d\theta \\ =& C_p(\kappa) \frac{2\pi^{\frac{p-1}{2}}}{\Gamma(\frac{p-1}{2})} \int_{1}^{-1} (1 + t)^\kappa (1-t^2)^\frac{p-2}{2} (\frac{-1}{\sqrt{1-t^2}} dt) & \because cos0=1,~ cos\pi=-1 \\ =& C_p(\kappa) \frac{2\pi^{\frac{p-1}{2}}}{\Gamma(\frac{p-1}{2})} \int_{-1}^{1} (1 + t)^\kappa (1-t^2)^{\frac{p-3}{2}} dt \\ =& C_p(\kappa) \frac{2\pi^{\frac{p-1}{2}}}{\Gamma(\frac{p-1}{2})} \int_{-1}^{1} (1 + t)^\kappa [(1+t)(1-t)]^{\frac{p-3}{2}} dt \\ =& C_p(\kappa) \frac{2\pi^{\frac{p-1}{2}}}{\Gamma(\frac{p-1}{2})} \int_{-1}^{1} (1 + t)^{\frac{p-1}{2}+\kappa-1} (1-t)^{\frac{p-1}{2}-1} dt \\ =& C_p(\kappa) \frac{2\pi^{\frac{p-1}{2}}}{\Gamma(\frac{p-1}{2})} \int_{0}^{1} (2z)^{\frac{p-1}{2}+\kappa-1} (2(1-z))^{\frac{p-1}{2}-1} 2dz & t=2z-1 \\ =& C_p(\kappa) \frac{2^{p+\kappa-1}\pi^{\frac{p-1}{2}}}{\Gamma(\frac{p-1}{2})} \int_{0}^{1} z^{\frac{p-1}{2}+\kappa-1} (1-z)^{\frac{p-1}{2}-1} dz \end{aligned} ======0πCp(κ)(1+cosθ)κΓ(2p1)2π2p1sinp2θdθCp(κ)Γ(2p1)2π2p111(1+t)κ(1t2)2p2(1t2 1dt)Cp(κ)Γ(2p1)2π2p111(1+t)κ(1t2)2p3dtCp(κ)Γ(2p1)2π2p111(1+t)κ[(1+t)(1t)]2p3dtCp(κ)Γ(2p1)2π2p111(1+t)2p1+κ1(1t)2p11dtCp(κ)Γ(2p1)2π2p101(2z)2p1+κ1(2(1z))2p112dzCp(κ)Γ(2p1)2p+κ1π2p101z2p1+κ1(1z)2p11dzcos0=1, cosπ=1t=2z1 那么, 将 α = p − 1 2 + κ β = p − 1 2 C p ( κ ) 2 α + β π β Γ ( β ) = Γ ( α + β ) Γ ( α ) Γ ( β ) C p ( κ ) = Γ ( α + β ) 2 α + β π β Γ ( α ) \begin{aligned} \alpha =& \frac{p-1}{2}+\kappa \\ \beta =& \frac{p-1}{2} \\ C_p(\kappa) \frac{2^{\alpha+\beta}\pi^{\beta}}{\Gamma(\beta)} =& \frac{\Gamma(\alpha+\beta)}{\Gamma(\alpha)\Gamma(\beta)} \\ C_p(\kappa) =& \frac{\Gamma(\alpha+\beta)}{2^{\alpha+\beta}\pi^{\beta}\Gamma(\alpha)} \end{aligned} α=β=Cp(κ)Γ(β)2α+βπβ=Cp(κ)=2p1+κ2p1Γ(α)Γ(β)Γ(α+β)2α+βπβΓ(α)Γ(α+β)

3. ComposeTransformdomain.event_dim
def domain(self):
	if not self.parts:
		return constraints.real
	domain = self.parts[0].domain
	# Adjust event_dim to be maximum among all parts.
	event_dim = self.parts[-1].codomain.event_dim
	for part in reversed(self.parts):
		event_dim += part.domain.event_dim - part.codomain.event_dim  # 一个 Transform 的 event_dim 降了多少, 加回来
		# 如果当前 Transform 的 codomain.event_dim 等于下一个 Transform 的 domain.event_dim, 那么下面两者一定相等;
		event_dim = max(event_dim, part.domain.event_dim)  # Transform 前后口对接不一致才会出现差别, 且大概率是前一个出口大于后一个入口 ->
	# -> 假设现在 event_dim 等于后一个 domain.event_dim(事实上刚开始是这样的), 那么 event_dim - part.codomain.event_dim < 0 ->
	assert event_dim >= domain.event_dim  # -> 意味着 event_dim += part.domain.event_dim - part.codomain.event_dim < part.domain.event_dim
	# 从而 event_dim = max(event_dim, part.domain.event_dim) = part.domain.event_dim, 还是和接口一致相同的结果.
	# 所以, 只有前一个出口小于后一个入口时, 才会保留计算得到的 event_dim, 而不是更新为 part.domain.event_dim
	# ===> 下面一句是结论:
	if event_dim > domain.event_dim:  # 这个条件正是说明链条中有"前一个出口小于后一个入口"的接口不兼容情况, 故而扩展开头的 domain, 以免出错; ->
		domain = constraints.independent(domain, event_dim - domain.event_dim)  # 目的是留够 dim, 足够后面的"小->大"消耗不兼容性.
	return domain
4. TTransformlog_abs_det_jacobian

经过试验, 发现的确是 ((x.shape[-1] - 3) / 2) * torch.log(torch.clamp(1 - t ** 2, _EPS)) 对:

class MarginalTDistribution(distributions.Distribution):
	arg_constraints = {'kappa': constraints.positive}
	has_rsample = True
	_EPS = 1e-36

	def __init__(self, m, kappa, dtype=None, validate_args=None):
		self.m = m
		self.kappa = kappa
		super().__init__(batch_shape=kappa.shape[:-1], event_shape=kappa.shape[-1:], validate_args=validate_args)

		self._dtype = dtype if dtype is not None else kappa.dtype
		self._device = kappa.device

		# >>> for sampling algorithm >>>
		self._uniform = distributions.Uniform(self._EPS, torch.tensor(1.0 - self._EPS, dtype=self._dtype, device=self._device))
		self._beta = distributions.Beta(
			torch.tensor((self.m - 1) / 2, dtype=self._dtype, device=self._device),
			torch.tensor((self.m - 1) / 2, dtype=self._dtype, device=self._device)
		)

	def sample(self, shape: Union[torch.Size, int] = torch.Size()):
		if isinstance(shape, int):
			shape = torch.Size([shape])
		with torch.no_grad():  # rsample 是 reparameterized sample, 便于梯度更新以调整分布参数
			return self.rsample(shape)

	def rsample(self, shape=torch.Size()):
		"""
		Reparameterized Sample: 从一个简单的分布通过一个参数化变换使得其满足一个更复杂的分布;
		此处, mu 是可变参数, 通过 radial-tangential decomposition 采样;
		梯度下降更新 mu, 以获得满足要求的 vMF.
		:param shape: 样本的形状
		:return: [shape|m] 的张量, shape 个 m 维方向向量
		"""
		# shape = torch.Size(shape + self._batch_shape + self._event_shape)
		w = (
			self._sample_w3(shape=shape)
			if self.m == 3
			else self._sample_w_rej(shape=shape)
		)
		return w

	def _sample_w3(self, shape: torch.Size):
		"""
		拒绝采样方法来自: Computer Generation of Distributions on the M-Sphere
		https://rss.onlinelibrary.wiley.com/doi/abs/10.2307/2347441
		:param shape: 采样 w 的的形状
		:return: 形状为 shape 的张量, shape 个 w
		"""
		# kappa 也是有 shape 的, 说明可以并行多个 κ 吗?
		shape = torch.Size(shape + self.kappa.shape)  # torch.Size 继承自 tuple, 其 + 运算就是连接操作
		# https://en.wikipedia.org/wiki/Von_Mises%E2%80%93Fisher_distribution # 3-D sphere
		u = self._uniform.sample(shape)
		w = 1 + torch.stack(  # 这个公式是按 μ=(0,0,1) 计算的 w, arccosw=φ, 即 w=z
			[  # 最后的旋转可能是旋转至按真正的 μ 采样结果
				torch.log(u),
				torch.log(1 - u) - 2 * self.kappa
			],
			dim=0
		).logsumexp(0) / self.kappa
		return w

	def _sample_w_rej(self, shape: torch.Size):
		num_need = math.prod(shape)
		num_kappa = math.prod(self.kappa.shape)
		sample_shape = torch.Size([num_kappa, 10 + math.ceil(num_need * 1.3)])

		kappa = self.kappa.reshape(-1, 1)
		c = torch.sqrt(4 * kappa.square() + (self.m - 1) ** 2)
		b = (-2 * kappa + c) / (self.m - 1)
		t_0 = (-(self.m - 1) + c) / (2 * kappa)
		s = kappa * t_0 + (self.m - 1) * torch.log(1 - t_0.square())

		# 大天坑, [[]] * num_kappa 虽然也形成了 [num_kappa]个[]], 但实际上这些 [] 是同一个对象;
		# 导致后面的 ts[i] 全是同一个 [], 并 append 了所有不同 kappa 对应的 t 样本;
		# [torch.cat(t)[:num_need] for t in ts] 中每个 tensor 也仅仅获取了第一个 kappa 对应的 t 样本;
		# 因为 Python 的 * 运算符是浅拷贝, 而不是深拷贝.
		ts = [[] for _ in range(num_kappa)]  # * num_kappa
		cnts = torch.zeros(num_kappa, device=self._device)
		while cnts.lt(num_need).any():
			y = self._beta.sample(sample_shape)
			u = self._uniform.sample(sample_shape)
			t = (1 - (1 + b) * y) / (1 - (1 - b) * y)
			mask = (kappa * t + (self.m - 1) * torch.log(1 - t_0 * t) - s) > torch.log(u)
			for i in range(num_kappa):
				ts[i].append(t[i][mask[i]])
			cnts += mask.sum(dim=-1)
		samples = torch.stack([torch.cat(t)[:num_need] for t in ts], dim=1)
		return samples.reshape(shape + self.kappa.shape)

	def log_prob(self, value: torch.Tensor) -> torch.Tensor:
		result = (self.m / 2 - 1) * torch.log(self.kappa / 2)
		result -= math.lgamma(0.5) + math.lgamma(self.m / 2 - 0.5) + torch.log(ive(self.m / 2 - 1, self.kappa)) + self.kappa
		result += value * self.kappa + (self.m / 2 - 1 - 0.5) * torch.log(1 - value.square())
		return result


class _JointTSDistribution(distributions.Distribution):
	arg_constraints = {}

	def __init__(self, marginal_t, marginal_s):
		super().__init__(batch_shape=marginal_s.batch_shape, event_shape=torch.Size([marginal_s._dim + 1]), validate_args=False)
		self.marginal_t, self.marginal_s = marginal_t, marginal_s

	def sample(self, sample_shape=()):
		with torch.no_grad():
			return self.rsample(sample_shape)

	def rsample(self, sample_shape=()):
		return torch.cat(
			[
				self.marginal_t.rsample(sample_shape).unsqueeze(-1),
				self.marginal_s.sample(sample_shape + self.marginal_t.kappa.shape),
			],
			-1
		)

	def log_prob(self, value):
		return self.marginal_t.log_prob(value[..., 0]) + self.marginal_s.log_prob(value[..., 1:])


class VonMisesFisher(distributions.TransformedDistribution):
	arg_constraints = {
		'loc': constraints.real_vector,
		'scale': constraints.positive,
	}
	has_rsample = True

	def __init__(self, loc, scale, validate_args=None):
		self.loc, self.scale, = loc, scale
		super().__init__(
			_JointTSDistribution(
				MarginalTDistribution(
					loc.shape[-1], scale, validate_args=validate_args
				),
				HypersphericalUniform(
					loc.shape[-1] - 1,
					batch_shape=loc.shape[:-1],
					validate_args=validate_args,
				),
			),
			[TTransform(), HouseholderRotationTransform(loc)],
		)

	def log_prob_true(self, x):
		if self._validate_args:
			self._validate_sample(x)
		return self._log_unnormalized_prob(x) + self._log_normalization()

	def _log_unnormalized_prob(self, x):  # k<μ,x>
		kappa = self.base_dist.marginal_t.kappa
		mu = self.loc
		return kappa * (mu * x).sum(-1, keepdim=True)

	def _log_normalization(self):  # logCp(kappa)
		m = self.loc.shape[-1]
		kappa = self.base_dist.marginal_t.kappa
		return (
				(m / 2 - 1) * torch.log(kappa)
				- (m / 2) * math.log(2 * math.pi)
				- torch.log(ive(m / 2 - 1, kappa)) - kappa
		)


if __name__ == '__main__':
	vmf = VonMisesFisher(F.normalize(torch.randn(8), dim=-1), torch.tensor([3.0]))
	x = vmf.sample(torch.Size([2]))
	print(x)
	print(vmf.log_prob(x))
	print(vmf.log_prob_true(x))

	from scipy import stats
	svmf = stats.vonmises_fisher(vmf.loc.numpy(), vmf.scale.numpy().item())
	print(svmf.logpdf(x.detach().numpy()))

但我始终想不明白为什么.

According to TTransform, we have [ t , 1 − t 2 v 1 , ⋯   , 1 − t 2 v m − 1 ] = t t r a n s f o r m ( [ t , v 1 , ⋯   , v m − 1 ] ) [t, \sqrt{1-t^2}v_1, \cdots, \sqrt{1-t^2}v_{m-1}] = ttransform([t, v_1, \cdots, v_{m-1}]) [t,1t2 v1,,1t2 vm1]=ttransform([t,v1,,vm1]). Then the Jacobian matrix: J = [ 1 0 ⋯ 0 − t v 1 1 − t 2 1 − t 2 ⋮ 0 ⋮ ⋮ ⋱ ⋮ − t v m − 1 1 − t 2 0 ⋯ 1 − t 2 ] J = \begin{bmatrix} 1 & 0 & \cdots & 0 \\ \frac{-tv_1}{\sqrt{1-t^2}} & \sqrt{1-t^2} & \vdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ \frac{-tv_{m-1}}{\sqrt{1-t^2}}& 0 & \cdots & \sqrt{1-t^2} \end{bmatrix} J= 11t2 tv11t2 tvm101t2 0001t2 ⇒ d e t ( J ) = ( 1 − t 2 ) m − 1 2 \Rightarrow det(J) = (1-t^2)^\frac{m-1}{2} det(J)=(1t2)2m1, I get the result l o g ∣ J ∣ = m − 1 2 l o g ( 1 − t 2 ) log|J| = \frac{m-1}{2}log(1-t^2) logJ=2m1log(1t2).

### 关于《强化学习导论》第13章的内容 #### 13.1 政策梯度方法简介 政策梯度方法代表了一类通过参数化策略直接优化性能指标的方法。这类方法不依赖价值函数作为中介,而是直接调整行动决策的概率分布以最大化长期奖励期望。这种方法允许处理连续动作空间的问题,在机器人控制等领域具有广泛应用前景[^2]。 #### 13.2 REINFORCE算法及其变体 REINFORCE是一种基于采样的无模型策略梯度算法,它利用蒙特卡罗估计计算目标函数相对于策略参数的梯度。该章节深入探讨了如何有效减少方差的技术以及引入基线来提高收敛速度和稳定性的方式。此外还讨论了几种改进版本如带资格迹的REINFORCE等[^3]。 #### 13.3 动作-价值方法与演员评论家架构 为了克服仅依靠回报信号指导探索过程所带来的高方差问题,《强化学习导论》介绍了结合动作价值评估机制的动作-批评者框架。这种双网络结构不仅能够加速学习进程而且可以更好地平衡勘探与开发之间的关系。具体实现形式包括A3C (Asynchronous Advantage Actor-Critic)[^4]。 #### 13.4 近端策略优化(PPO) 近端策略优化旨在解决传统策略梯度方法容易发散的问题。PPO通过对每次迭代中的更新幅度施加限制条件从而确保新旧策略之间不会发生剧烈变化。此部分讲解了不同类型的约束方式及其对应的优缺点分析,并提供了实际应用案例说明其有效性[^5]。 ```python import torch.nn as nn from torch.distributions import Categorical class PolicyNetwork(nn.Module): def __init__(self, input_dim, hidden_dim, output_dim): super().__init__() self.network = nn.Sequential( nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, output_dim), nn.Softmax(dim=-1)) def forward(self, state): action_probs = self.network(state.float()) dist = Categorical(action_probs) return dist.sample(), dist.log_prob(dist.sample()) policy_net = PolicyNetwork(4, 128, 2) # Example initialization with CartPole environment dimensions. ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值