SE(3) 等变的生成扩散模型
在蛋白质结构生成中遇到的一篇看到作者单位就感觉很难读的文献🥹(SE(3) diffusion model with application to protein backbone generation)
为了后续的研究需要,学习一下SE(3)等变的生成扩散模型。
1 扩散模型初步
之前组会报告准备过一次关于一般扩散模型的内容扩散模型简介,但是模型不具有SE(3)等变性质。
2 什么是 SE(3)?为什么需要SE(3)?
2.1. SE(3) 的定义
SE(3)(Special Euclidean group in 3D,即三维特殊欧几里得群),它包括旋转和平移两部分。数学上,SE(3) 定义为:
\[ SE(3) = \{(R, x) \mid R \in SO(3), x \in \mathbb{R}^3\} \]其中:
- \( R \in SO(3) \) 是一个 \( 3 \times 3 \) 正交矩阵(即 \( R^\top R = I \))且 \(\det(R) = 1\),表示三维旋转。
- \( x \in \mathbb{R}^3 \) 是一个三维平移向量,表示刚体在三维空间的平移。
SE(3) 的群运算是:
\[ (R_1, x_1) \cdot (R_2, x_2) = (R_1 R_2, R_1 x_2 + x_1) \]它表示一个刚体变换 \( (R_2, x_2) \) 作用后,再应用变换 \( (R_1, x_1) \) 的结果。
2.2. 为什么需要 SE(3)?
涉及三维刚体的问题中,特别是在分子模拟、机器人学、计算机视觉等领域,物体的状态不仅仅是一个位置(平移),还包括方向(旋转)。SE(3) 提供了一个数学框架,使得我们可以在同一个空间中处理这些信息。
在 SE(3) 空间上进行建模的一个关键优势是,它能够保持物理和几何的不变性。具体来说:
- 在蛋白质建模中,蛋白质的结构和功能通常仅取决于其相对构象,而不是其绝对位置或方向。因此,我们希望构造的模型具有对称性,不会因坐标选取不同而导致结果不同。
- 在机器人学中,机器人手臂、自动驾驶汽车、无人机的姿态(pose)通常在 SE(3) 上进行建模,以便在不同的参考系之间转换。
- 在计算机视觉中,物体识别和跟踪通常需要对旋转和平移具有鲁棒性,因此 SE(3) 变换在 3D 目标检测和姿态估计任务中十分重要。
2.3. 蛋白质中的 SE(3) 是什么?
在蛋白质结构建模中,平移通常指的是 Cα 的平移 ,而旋转则涉及整个框架(frame)的旋转。
每个氨基酸残基由一组固定的主链原子(N、Cα、C、O)组成,这些原子在理想化情况下保持相对固定的几何关系。旋转的核心是这个刚性框架(frame)的方向,即以 Cα 作为中心的旋转。
蛋白质的主链是通过共价键连接的,每个残基的 N、Cα、C 具有固定的几何关系(键长、键角)。因此,在建模过程中,我们可以将每个残基视为一个刚体,即只允许其进行整体的刚体旋转,而不改变内部键长和键角。
数学上,我们用一个 SE(3) 变换 \( T = (R, x) \) 来表示残基的空间位姿:
- \( x \in \mathbb{R}^3 \) :表示 Cα 的平移位置(即氨基酸的中心位置)。
- \( R \in SO(3) \) :表示该残基在三维空间中的 方向 ,即 N、Cα、C 三个原子构成的刚体的旋转。
由于 SE(3) 变换是作用在刚体上的,我们需要一个合适的坐标系来描述残基的方向。通常,我们使用Gram-Schmidt 正交化来构造一个正交基,以描述 N、Cα、C 的相对方向,从而定义\(R\)。
如果我们用 \( T_n = (R_n, x_n) \) 表示第 \( n \) 个残基的 pose,那么:
\[ [N_n, C_n, (Cα)_n] = T_n \cdot [N^*, C^*, Cα^*], \]其中:\( [N^*, C^*, Cα^*] \) 是理想化残基的标准坐标。\( T_n = (R_n, x_n) \) 作用在这些坐标上,给出第 \( n \) 个残基在三维空间中的实际位置。
3. SE(3) 扩散模型
3.1. SE(3) 扩散模型的基本思想
为了在 SE(3) 上进行扩散建模,我们需要定义 SE(3) 上的随机过程。然但,在 SE(3) 上直接进行扩散会带来计算上的复杂性。通过对 SE(3) 的度规进行适当的选择,我们可以证明:
\[ dT(t) = (dR(t), dX(t)), \]即 SE(3) 的扩散过程可以拆分成:
- SO(3) 上的旋转扩散 \( R(t) \) (用于建模残基刚体框架的方向)。
- \( \mathbb{R}^3 \) 上的平移扩散 \( X(t) \) (用于建模 Cα 原子的位置)。
这一拆分的合理性基于:
- SE(3) 可以用 SO(3) 和 ℝ³ 的直积表示,因此可以在两个部分分别定义扩散过程。
- 在适当的度规下,SE(3) 上的拉普拉斯–贝尔特拉米算子(Riemannian 流形上的 Laplace 算子)可以拆分成 SO(3) 和 ℝ³ 上的拉普拉斯算子之和 : \[ \Delta_{SE(3)} = \Delta_{SO(3)} + \Delta_{\mathbb{R}^3}. \]
- SO(3) 上的扩散可以用李群上的布朗运动建模,而 ℝ³ 上的扩散可以用经典的 Ornstein–Uhlenbeck 过程建模。
拉普拉斯–贝尔特拉米算子
在 \( \mathbb{R}^n \) 中,标准的拉普拉斯算子(Laplace Operator)定义为:
\[ \Delta f = \nabla \cdot \nabla f = \sum_{i=1}^{n} \frac{\partial^2 f}{\partial x_i^2}. \]它表示标量场 \( f(x) \) 的二阶导数之和,通常用于描述:
对于一个 一般的 Riemannian 流形 \( (M, g) \) ,由于坐标系统不再是欧几里得直角坐标,我们需要用度规 \( g \) 来定义拉普拉斯算子,即拉普拉斯–贝尔特拉米算子。
\[ \Delta_M f = \text{div} (\text{grad} f). \]- 梯度算子(Gradient Operator) : \[ \text{grad} f = g^{ij} \frac{\partial f}{\partial x^j} \frac{\partial}{\partial x^i}. \] 其中 \( g^{ij} \) 是度规张量 \( g_{ij} \) 的逆矩阵。
- 散度算子(Divergence Operator) : \[ \text{div} X = \frac{1}{\sqrt{|g|}} \frac{\partial}{\partial x^i} \left( \sqrt{|g|} X^i \right). \] 其中 \( |g| = \det(g_{ij}) \) 是度规张量的行列式。
在局部坐标系 \( \{x^i\} \) 下,拉普拉斯–贝尔特拉米算子的显式形式为:
\[ \Delta_M f = \frac{1}{\sqrt{|g|}} \sum_{i,j} \frac{\partial}{\partial x^i} \left( \sqrt{|g|} g^{ij} \frac{\partial f}{\partial x^j} \right). \]这一公式在一般流形上定义了拉普拉斯算子,它依赖于流形的几何结构(即度量 \( g \)),因此在前面提到的 SE(3) 上的度规就产生了作用。
SE(3) 上的度规
这篇文章在 SE(3) 上选取的度量张量 \( g \) 使得 SE(3) 作为 Riemannian 流形可以分解为 \( SO(3) \) 和 \( \mathbb{R}^3 \) 的直积,并且保证旋转和平移部分的扩散过程可以独立建模。
对于 SE(3) 上的任意两个切向量 \( (a, x), (a', x') \in \text{Tan}_T SE(3) \),本文选取的内积定义如下:
\[ \langle (a, x), (a', x') \rangle_{SE(3)} = \langle a, a' \rangle_{SO(3)} + \langle x, x' \rangle_{\mathbb{R}^3}. \]其中:
- 旋转部分的度规(即 \( SO(3) \) 部分): \[ \langle a, a' \rangle_{SO(3)} = \frac{1}{2} \text{Tr}(a a'^T). \] 这里 \( a, a' \in so(3) \) 是 \( SO(3) \) 李代数中的元素(即反对称矩阵)。
- 平移部分的度规 (即 \( \mathbb{R}^3 \) 部分): \[ \langle x, x' \rangle_{\mathbb{R}^3} = \sum_{i=1}^{3} x_i x'_i. \] 这里 \( x, x' \in \mathbb{R}^3 \) 是标准欧几里得空间中的向量。
这种选取使得 SE(3) 在 Riemannian 意义下可视为 SO(3) 和 \( \mathbb{R}^3 \) 的直积 ,从而能够独立处理旋转和平移的扩散过程。并且拉普拉斯–贝尔特拉米算子可以分解为 \( \Delta_{SE(3)} = \Delta_{SO(3)} + \Delta_{\mathbb{R}^3} \) ,从而扩散过程可以分别建模。旋转的 Frobenius 范数对应于经典刚体动力学中的角动量度量。平移的欧几里得度量对应于质点的物理运动。
3.2. \( \mathbb{R}^3 \) 上的扩散
在 欧几里得空间 \( \mathbb{R}^3 \) 中 ,扩散过程可以使用 Ornstein–Uhlenbeck(OU)过程建模:
前向扩散过程
平移扩散建模 Cα 原子的位移,定义如下:
\[ dX(t) = -\frac{1}{2} X(t) dt + dB_{\mathbb{R}^3}(t), \]其中:\( X(t) \) 是 Cα 的位置。\( B_{\mathbb{R}^3}(t) \) 是标准布朗运动(Brownian motion)。漂移项 \( -\frac{1}{2} X(t) dt \) 使得扩散过程收敛到均匀噪声分布。
这个过程的解为:
\[ X(t) = e^{-t/2} X(0) + \sqrt{1 - e^{-t}} \cdot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I_3). \]其 转移概率密度是:
\[ p_t(X | X_0) = \mathcal{N}(X; e^{-t/2} X_0, (1 - e^{-t}) I_3). \]逆扩散过程
逆扩散过程是从噪声恢复数据,其公式为:
\[ dX(t) = \frac{1}{2} X(t) dt + \nabla_X \log p_t(X | X_0) dt + dB_{\mathbb{R}^3}(t). \]其中, 条件得分函数 为:
\[ \nabla_X \log p_t(X | X_0) = \frac{e^{-t/2} X_0 - X}{1 - e^{-t}}. \]训练过程中,我们让神经网络 \( s_\theta(X, t) \) 预测这个得分函数,最终用于反向采样。
Ornstein–Uhlenbeck (OU) 过程
一种特殊的随机过程,用于描述具有均值回归(mean-reverting)特性的随机运动。它是一种扩展的布朗运动(Brownian Motion)。OU 过程是满足以下随机微分方程(SDE, Stochastic Differential Equation)的过程:
\[ dX_t = -\theta (X_t - \mu) dt + \sigma dW_t, \]其中:\( X_t \) 是随时间 \( t \) 变化的随机变量(状态变量)。\( \theta > 0 \) 是 均值回归速率(mean-reverting rate) ,控制 \( X_t \) 向均值 \( \mu \) 回归的速度。\( \mu \) 是长期均值(long-term mean) ,表示 OU 过程最终趋向的值。\( \sigma \) 是 扩散系数(diffusion coefficient) ,控制噪声的强度。\( W_t \) 是标准布朗运动(Wiener 过程)。
OU 过程的解可以用伊藤积分(Itô Integral) 表示:
\[ X_t = X_0 e^{-\theta t} + \mu (1 - e^{-\theta t}) + \sigma \int_0^t e^{-\theta (t-s)} dW_s. \]初始状态 \( X_0 \) 会随着时间指数衰减 \( e^{-\theta t} \) ,逐渐被长期均值 \( \mu \) 吸引。噪声项 \( \sigma dW_t \) 影响 \( X_t \) 的随机波动 ,但不会无限增大,因此 OU 过程具有均值回归特性。
OU 过程的期望值为:
\[ \mathbb{E}[X_t] = X_0 e^{-\theta t} + \mu (1 - e^{-\theta t}). \]随着 \( t \to \infty \),期望值收敛到长期均值:
\[ \lim_{t \to \infty} \mathbb{E}[X_t] = \mu. \]OU 过程的方差为:
\[ \text{Var}[X_t] = \frac{\sigma^2}{2\theta} \left(1 - e^{-2\theta t} \right). \]当 \( t \to \infty \) 时,方差收敛到 稳态方差 :
\[ \lim_{t \to \infty} \text{Var}[X_t] = \frac{\sigma^2}{2\theta}. \]当 \( t \to \infty \),\( X_t \) 服从稳态分布:
\[ X_{\infty} \sim \mathcal{N} \left( \mu, \frac{\sigma^2}{2\theta} \right). \]即 OU 过程最终趋于一个 均值为 \( \mu \),方差为 \( \sigma^2 / (2\theta) \) 的正态分布 。
在SE(3) 扩散模型中,需要在 \( \mathbb{R}^3 \) 上进行扩散建模。为了防止 Cα 原子的平移部分无限扩散,使用均值回归 OU 过程来约束扩散:
\[ dX_t = -\frac{1}{2} X_t dt + dW_t. \]这是一个特殊的 OU 过程,其中:均值回归速率 \( \theta = \frac{1}{2} \)。长期均值 \( \mu = 0 \),表示扩散最终收敛到零均值。方差随时间 \( t \) 变化:
\[ p_t(X | X_0) = \mathcal{N} \left( e^{-t/2} X_0, (1 - e^{-t}) I_3 \right). \]逆扩散过程中,神经网络预测 OU 过程的 得分函数 :
\[ \nabla \log p_t(X | X_0) = \frac{e^{-t/2} X_0 - X}{1 - e^{-t}}. \]该得分用于去噪采样。
过程 | 公式 | 主要特性 |
---|---|---|
布朗运动 | \( dX_t = \sigma dW_t \) | 无均值回归,方差随时间无限增长 |
OU 过程 | \( dX_t = -\theta (X_t - \mu) dt + \sigma dW_t \) | 具有均值回归,方差最终收敛到稳态值 |
从表中我们可以得到: |
- 如果使用布朗运动 ,数据分布会变得无界,难以收敛。
- OU 过程可确保分布不会无限扩散,而是收敛到稳态分布,从而保持数值稳定性。
3.3. SO(3) 上的扩散
在 SO(3) 上,扩散过程通过李群上的布朗运动 进行建模,受控于 SO(3) 上的 拉普拉斯–贝尔特拉米算子 \( \Delta_{SO(3)} \)。
前向扩散过程
SO(3) 上的扩散过程可以由以下随机微分方程(SDE)描述:
\[ dR(t) = R(t) \cdot dB_{SO(3)}(t), \]其中:\( R(t) \in SO(3) \) 是旋转矩阵。\( B_{SO(3)}(t) \) 是 SO(3) 上的布朗运动。该过程的稳态分布为均匀分布 \( U(SO(3)) \)。
扩散过程的转移概率密度由热核(heat kernel)展开给出:
\[ p_t(R | R_0) = \sum_{\ell=0}^{\infty} (2\ell + 1) e^{-\ell(\ell+1)t/2} \chi_\ell(R_0^{-1} R). \]其中:\( \chi_\ell(R) \) 是 SO(3) 上的特征函数(character function)。\( e^{-\ell(\ell+1)t/2} \) 控制扩散程度。
热核(heat kernel)展开
热核(Heat Kernel)是一种在流形或李群上描述扩散过程的基本解。它描述了热方程(heat equation)在某个几何空间上的解,并在扩散模型、概率论、李群分析和物理学中有广泛应用。
在流形 \( M \) 上,热方程的形式为:
\[ \frac{\partial p}{\partial t} = \frac{1}{2} \Delta_M p. \]其中:\( p = p(x, t) \) 表示随时间 \( t \) 变化的概率密度函数(或热分布)。\( \Delta_M \) 是流形 \( M \) 上的 拉普拉斯–贝尔特拉米算子。热核 \( K_t(x, y) \) 是该偏微分方程的基本解,即满足:
\[ \frac{\partial K_t}{\partial t} = \frac{1}{2} \Delta_M K_t. \]直观理解:\( K_t(x, y) \) 表示在时间 \( t \) 内,热量从点 \( y \) 传播到 \( x \) 的概率密度。
在欧几里得空间 \( \mathbb{R}^n \) 上,热方程的解可以用标准高斯分布表示:
\[ K_t(x, y) = \frac{1}{(4\pi t)^{n/2}} \exp\left(-\frac{\|x - y\|^2}{4t}\right). \]这表示在 \( \mathbb{R}^n \) 中,扩散过程的概率密度服从均值为 \( y \),方差随时间 \( t \) 变化的高斯分布。
在李群(如旋转群 \( SO(3) \) )上,由于群的非阿贝尔结构,热核不能简单地用高斯分布表示。这里,我们使用特征函数展开(Character Expansion)来表示热核。SO(3) 上的热核可以展开为:
其中:\( \ell \) 是 SO(3) 的表示索引,对应于不同的角动量状态。\( e^{-\ell(\ell+1)t/2} \) 控制扩散速度,表示较高阶的频率成分随时间衰减得更快。\( \chi_\ell(R) \) 是 SO(3) 的迹函数(character function),描述群元素 \( R \) 之间的关系。
在欧几里得空间中,热核表示一个布朗运动的转移概率。在李群(如 SO(3))上,热核描述的是角动量扩散,它由不同角动量的特征函数贡献。这个展开式表明,SO(3) 上的扩散过程可以通过有限维不可约表示(irreducible representations)进行描述。
SO(3) 旋转群上的热核 \( p_t(R | R_0) \) 描述了布朗运动在 SO(3) 上的转移概率密度。它满足流形上的热方程:
\[ \frac{\partial}{\partial t} p_t(R | R_0) = \frac{1}{2} \Delta_{SO(3)} p_t(R | R_0), \]其中 \( \Delta_{SO(3)} \) 是SO(3) 上的拉普拉斯–贝尔特拉米算子 。
在紧李群上,拉普拉斯–贝尔特拉米算子的本征函数是不可约表示的特征函数(characters)。可以利用特征函数展开(Fourier 级数在群上的推广)得到热核的显式形式。
在紧李群 \( SO(3) \) 上,拉普拉斯–贝尔特拉米算子是由李代数上的左不变向量场 \( X_i \) 生成的:
\[ \Delta_{SO(3)} = X_1^2 + X_2^2 + X_3^2. \]这个算子在SO(3) 的球谐函数(即角动量本征函数)下是对角的,因此我们可以使用它们来展开热核。
\[ f(R) = \sum_{\ell=0}^{\infty} (2\ell + 1) a_\ell \chi_\ell(R), \]其中:\( \ell \) 是角动量量子数,类似于球谐函数的索引。\( \chi_\ell(R) \) 是SO(3) 的迹(character function),描述了旋转矩阵 \( R \) 在不可约表示空间上的迹 :
\[ \chi_\ell(R) = \text{Tr} (D^{(\ell)}(R)). \]其中 \( D^{(\ell)}(R) \) 是 SO(3) 的 Wigner D 矩阵 (即 SO(3) 在角动量 \( \ell \) 维空间中的表示)。系数 \( (2\ell + 1) \) 由 SO(3) 的群结构确定。
对于热核,我们需要求解:
\[ \frac{\partial}{\partial t} p_t(R | R_0) = \frac{1}{2} \Delta_{SO(3)} p_t(R | R_0). \]利用特征函数展开:
\[ p_t(R | R_0) = \sum_{\ell=0}^{\infty} c_\ell(t) \chi_\ell(R_0^{-1} R). \]将其代入热方程:
\[ \frac{\partial}{\partial t} \sum_{\ell=0}^{\infty} c_\ell(t) \chi_\ell(R_0^{-1} R) = \frac{1}{2} \sum_{\ell=0}^{\infty} c_\ell(t) \lambda_\ell \chi_\ell(R_0^{-1} R). \]其中:\( \lambda_\ell = \ell (\ell + 1) \) 是 SO(3) 上的拉普拉斯算子的本征值 ,类似于球谐函数中的本征值。
由于特征函数是正交的,每个项单独满足:
\[ \frac{\partial}{\partial t} c_\ell(t) = \frac{1}{2} \lambda_\ell c_\ell(t). \]解这个微分方程得到:
\[ c_\ell(t) = c_\ell(0) e^{-\frac{1}{2} \lambda_\ell t}. \]根据初始条件(\( t = 0 \) 时 \( p_t(R | R_0) \) 是 \( \delta(R - R_0) \)),可以确定:
\[ c_\ell(0) = (2\ell + 1). \]最终得到:
\[ p_t(R | R_0) = \sum_{\ell=0}^{\infty} (2\ell + 1) e^{-\ell (\ell + 1) t/2} \chi_\ell(R_0^{-1} R). \]这个公式类似于欧几里得空间中高斯分布的 Fourier 变换,只不过 SO(3) 上的“频率分量”由角动量 \( \ell \) 控制。高阶分量 \( \ell \) 在时间 \( t \) 增大时衰减得更快,使得最终分布收敛到均匀分布 \( U(SO(3)) \)。该热核描述了SO(3) 空间上的扩散过程 ,从某个初始旋转 \( R_0 \) 开始,随着时间推移,分布变得更加均匀。
\( t \to \infty \) 时,\( e^{-\ell (\ell + 1) t/2} \) 使得所有高阶分量衰减,最终收敛到均匀分布:
\[ p_\infty(R) = \frac{1}{|SO(3)|} = \frac{1}{8\pi^2}. \]这意味着长时间扩散后,旋转矩阵分布均匀。
逆扩散过程
在逆扩散过程中,我们需要学习条件得分函数 :
\[ \nabla_R \log p_t(R | R_0). \]可以近似表示为:
\[ \nabla_R \log p_t(R | R_0) = R \cdot \frac{\omega(t) \log(R_0^{-1} R) \partial_\omega f(\omega, t)}{f(\omega, t)}, \]其中:\( \omega \) 是旋转角度。\( f(\omega, t) \) 是旋转扩散的热核函数。
神经网络 \( s_\theta(R, t) \) 训练过程中拟合这个得分函数,并在采样过程中用于去噪。
四元数
在代码实现中发现了一个很重要但是之前知听说过但没有涉及过的领域:四元数。只知道四元数是一种被抛弃的表示旋转的方法,但搜索资料发现四元数在深度学习领域应用不少(分子模拟以及机器人学等),于是决定整理一下。
在三维旋转群 \( SO(3) \) 上,旋转通常可以用旋转矩阵或四元数 表示:
- 旋转矩阵 \( R \)(3×3 矩阵) :直接用于 SE(3) 变换,但涉及非线性约束(正交性)。
- 四元数 \( q \)(4 维向量) :更紧凑,仅需 4 维参数(比旋转矩阵少 5 个自由度),并且避免了数值不稳定性(如 Gimbal Lock),这在算法实现中是相当重要的一点。
四元数 \( q \) 定义为:
\[ q = (q_0, q_1, q_2, q_3) = (q_0, \mathbf{q}), \]其中:\( q_0 \) 为实部 ,代表旋转角度信息。\( \mathbf{q} = (q_1, q_2, q_3) \) 为虚部 ,代表旋转轴信息。
旋转矩阵和四元数的关系:
\[ R(q) = I + 2 q_0 [\mathbf{q}]_\times + 2 [\mathbf{q}]_\times^2, \]其中\( [\mathbf{q}]_\times \)是四元数的反对称矩阵:
\[ [\mathbf{q}]_\times = \begin{bmatrix} 0 & -q_3 & q_2 \\ q_3 & 0 & -q_1 \\ -q_2 & q_1 & 0 \end{bmatrix}. \]由于四元数避免了旋转矩阵的正交约束,并且具有更紧凑的表示,因此在神经网络的实现和 扩散过程的数值计算中更容易使用。根据之前的介绍,SO(3)上的扩散涉及
\[ p_t(R | R_0) = \sum_{\ell=0}^{\infty} (2\ell + 1) e^{-\ell (\ell + 1) t/2} \chi_\ell(R_0^{-1} R). \]在数值计算时,直接使用旋转矩阵 \( R \) 进行计算会导致数值不稳定性(如正交性误差),因此:
- 需要采用四元数 \( q \) 来表示旋转,并进行扩散计算。
- 通过四元数的指数映射,将扩散噪声添加到李代数(so(3))上: \[ q(t) = \exp\left( \frac{t}{2} \omega \right) \cdot q_0. \] 其中 \( \omega \) 是 so(3) 的随机噪声。
同样在逆向过程中需要学习旋转的得分函数:
\[ \nabla_R \log p_t(R | R_0). \]在数值计算时,直接优化旋转矩阵可能会导致不稳定性,因此:
- 使用四元数作为神经网络的输出,并进行归一化,以确保旋转的合法性: \[ q' = \frac{q}{\|q\|} \]
- 在每一步的去噪采样时,使用四元数的球面插值(Slerp, Spherical Linear Interpolation)进行平滑更新: \[ q(t - \Delta t) = \text{Slerp} (q_t, \hat{q}, \alpha), \] 其中 \( \hat{q} \) 是预测的四元数方向,\( \alpha \) 是步长。
这样,逆扩散过程可以在 SO(3) 上平滑运行,而不会出现旋转矩阵计算中的奇异性问题。
在训练神经网络时,输入是加噪后的旋转 \( q_t \) ,目标是学习去噪得分 \( \nabla \log p_t(q) \) 。由于四元数是单位球面上的点,需要使用了超球面损失(Hyperspherical Loss)来优化:
\[ L_{\text{rot}} = \mathbb{E}_{q_t} \left[ \| s_\theta(q_t, t) - \nabla_q \log p_t(q) \|^2 \right]. \]这里,\( s_\theta(q_t, t) \) 是神经网络的预测得分。
使用四元数可以避免:1.旋转矩阵的数值不稳定性(正交误差)。2.角度参数化(如欧拉角)导致的 Gimbal Lock(万向锁)。3.过多的自由度(旋转矩阵是 9 维,但只有 3 维自由度)。 四元数的使用,使得 SE(3) 扩散模型在旋转建模方面更加高效和鲁棒。
3.4. SE(3) 上的采样
在 SE(3) 上的采样是通过逆扩散过程进行的,并结合 SO(3) 和 \( \mathbb{R}^3 \) 的两个部分。
SE(3) 逆扩散过程可以写成:
\[ dT(t) = (dR(t), dX(t)) = \big( s_\theta(R, t) dt + dB_{SO(3)}(t), s_\theta(X, t) dt + dB_{\mathbb{R}^3}(t) \big). \]其中:旋转部分 \( R(t) \) 通过 SO(3) 的逆扩散恢复。平移部分 \( X(t) \) 通过 \( \mathbb{R}^3 \) 的逆扩散恢复。
从均匀噪声开始,并按照以下步骤进行采样:
- 初始化 :
- 采样一个随机旋转 \( R(T_F) \sim U(SO(3)) \)。
- 采样一个随机平移 \( X(T_F) \sim \mathcal{N}(0, I_3) \)。
- 逐步逆扩散 :
- 迭代计算: \[ dT(t) = \big( \nabla_R \log p_{T-t}(R), \nabla_X \log p_{T-t}(X) \big) dt + \text{noise}. \]
- 通过 Euler–Maruyama 方法 进行数值积分: \[ R_{t-\delta t} = R_t + \delta t \cdot s_\theta(R_t, t) + \sqrt{\delta t} \cdot \eta_R, \quad \eta_R \sim \mathcal{N}(0, I). \] \[ X_{t-\delta t} = X_t + \delta t \cdot s_\theta(X_t, t) + \sqrt{\delta t} \cdot \eta_X, \quad \eta_X \sim \mathcal{N}(0, I). \]
- 最终生成结构 :
- 当 \( t \to 0 \) 时,我们得到合理的蛋白质结构 \( (R(0), X(0)) \)。
4. 总结
总结:SE(3) 等变的生成扩散模型
SE(3) 扩散模型是一种基于 SE(3) 群的生成扩散模型,适用于蛋白质建模、分子模拟等涉及刚体运动的任务。
核心思想 :
- SE(3) = SO(3) × \( \mathbb{R}^3 \) ,可拆分为 旋转扩散(SO(3)) 和 平移扩散(\( \mathbb{R}^3 \)) ,分别建模蛋白质残基的方向和位置。
- SO(3) 扩散采用李群上的布朗运动,其转移概率由热核展开计算。
- \( \mathbb{R}^3 \) 扩散采用Ornstein–Uhlenbeck (OU) 过程,确保分布不会无限扩散。
四元数的应用 :
- 用于表示旋转,避免 Gimbal Lock 和数值不稳定性。
- 训练时用超球面损失学习得分函数,采样时用Slerp 插值平滑去噪。
逆扩散采样 :
- 通过Euler–Maruyama 方法迭代逆扩散方程,从均匀噪声逐步恢复蛋白质结构。
基本把原理搞清楚了,要开始考虑代码问题了(上万行的代码真是令人头大😵💫)。