Wigner D-矩阵
在计算球面卷积(spherical CNN)的时候,对图像和卷积核进行傅里叶变换,然后通过矩阵相乘和傅里叶逆变换,来进行卷积。其中,图像就是球面图像,第一层卷积网络的卷积核是在s2球面上卷积,从第二层开始后面的卷积核都是在SO(3)群上的卷积。
Wigner D 矩阵是SU(2)和SO(3)群的不可约表示中的酉矩阵。它由尤金·维格纳 (Eugene Wigner)于 1927 年提出,在角动量的量子力学理论中起着基础性作用。D 矩阵的复共轭是球形和对称刚性转子的哈密顿量的特征函数。Wigner D-矩阵就是卷积过程中SO(3)群的基底,相当于基础的傅里叶变换中的sin 和cos。
基础知识
若我们把球谐函数
Y
l
,
m
(
r
^
)
Y_{l, m}(\hat{\mathbf{r}})
Yl,m(r^)绕原点进行某种旋转,得到的函数可以表示成球谐函数的线性组合,且只需要同一个
l
l
l子空间中的球谐函数,若将旋转算符(主动)记为
R
R
R,则
R
∣
l
,
m
⟩
=
∑
m
′
∣
l
,
m
′
⟩
⟨
l
,
m
′
∣
R
∣
l
,
m
⟩
\mathcal{R}|l, m\rangle=\sum_{m^{\prime}}\left|l, m^{\prime}\right\rangle\left\langle l, m^{\prime}|\mathcal{R}| l, m\right\rangle
R∣l,m⟩=m′∑∣l,m′⟩⟨l,m′∣R∣l,m⟩
我们把系数矩阵称为 Wigner D 矩阵
D
m
′
,
m
l
=
⟨
l
,
m
′
∣
R
∣
l
,
m
⟩
.
D_{m^{\prime}, m}^l=\left\langle l, m^{\prime}|\mathcal{R}| l, m\right\rangle .
Dm′,ml=⟨l,m′∣R∣l,m⟩.
代码
e3nn
def change_basis_real_to_complex(l: int, dtype=None, device=None) -> torch.Tensor:
# https://en.wikipedia.org/wiki/Spherical_harmonics#Real_form
q = torch.zeros((2 * l + 1, 2 * l + 1), dtype=torch.complex128)
for m in range(-l, 0):
q[l + m, l + abs(m)] = 1 / 2**0.5
q[l + m, l - abs(m)] = -1j / 2**0.5
q[l, l] = 1
for m in range(1, l + 1):
q[l + m, l + abs(m)] = (-1) ** m / 2**0.5
q[l + m, l - abs(m)] = 1j * (-1) ** m / 2**0.5
q = (-1j) ** l * q # Added factor of 1j**l to make the Clebsch-Gordan coefficients real
dtype, device = explicit_default_types(dtype, device)
dtype = {
torch.float32: torch.complex64,
torch.float64: torch.complex128,
}[dtype]
# make sure we always get:
# 1. a copy so mutation doesn't ruin the stored tensors
# 2. a contiguous tensor, regardless of what transpositions happened above
return q.to(dtype=dtype, device=device, copy=True, memory_format=torch.contiguous_format)
def su2_generators(j) -> torch.Tensor:
m = torch.arange(-j, j)
raising = torch.diag(-torch.sqrt(j * (j + 1) - m * (m + 1)), diagonal=-1)
m = torch.arange(-j + 1, j + 1)
lowering = torch.diag(torch.sqrt(j * (j + 1) - m * (m - 1)), diagonal=1)
m = torch.arange(-j, j + 1)
return torch.stack(
[
0.5 * (raising + lowering), # x (usually)
torch.diag(1j * m), # z (usually)
-0.5j * (raising - lowering), # -y (usually)
],
dim=0,
)
def so3_generators(l) -> torch.Tensor:
X = su2_generators(l)
Q = change_basis_real_to_complex(l)
X = torch.conj(Q.T) @ X @ Q
assert torch.all(torch.abs(torch.imag(X)) < 1e-5)
return torch.real(X)
def wigner_D(l, alpha, beta, gamma):
alpha, beta, gamma = torch.broadcast_tensors(alpha, beta, gamma)
alpha = alpha[..., None, None] % (2 * math.pi)
beta = beta[..., None, None] % (2 * math.pi)
gamma = gamma[..., None, None] % (2 * math.pi)
X = so3_generators(l)
return torch.matrix_exp(alpha * X[1]) @ torch.matrix_exp(beta * X[0]) @ torch.matrix_exp(gamma * X[1])