药物设计中的SE3等变图神经网络层- EGNN 代码解析

此部分内容介绍了常用在药物设计深度学习中的SE3等变网络层 EGNN。主要对EGNN的代码逻辑、模块进行解析,并介绍其中的SE3等变在模型中的原理。

注:EGNN代码有多种。此部分EGNN代码来源于DiffLinker。其源头为EDM模型,DiffLinker进行了修改。

一、背景知识

在药物设计中,关于3D分子,通常被只考虑原子,并将原子表示为节点。

一个节点(原子)的特征分为两部分,坐标以及节点类型特征(原子种类,电荷,杂化等等)。其中,坐标是等变向量,要求在等变网络中符合SE3操作;节点类型是标量,具有不变性。即常规的MLP即可。

关于化学键:在小分子性质预测或者单纯的小分子生成任务中,化学键通常被利用为边,边的特征包括:化学键类型、键长等。但是在含有口袋的分子生成任务,等变网络往往不将化学键设置为边,而是在等变网络中,通过距离判断,哪些原子之间存在边,并且将距离作为边的特征,这是为了避免复杂的口袋氨基酸的共价网络,简化模型计算。另一方面,在等变图神经网络中,根据更新的坐标,可能会动态更新(创建、取消、更新)节点之间边,更有利于模型的收敛,避免长链路的消息传递。

关于mask,在很多神经网络中,我们经常可以看到各种mask。在EGNN中,我们常见的mask,有:node_mask和edge_mask。在数据中,训练模型时,为了效率都要将多个分子组成batch批次,但是不同的分子原子数量不同,组成批次时就会长短不一。所以,在组成批次时,就会添加一些所有特征包括坐标均为0的maks原子,node_mask就是记录那些位置是真实原子,哪些是mask原子。例如:分子A含有23个原子,分子B含有20个原子,当A和B组成batch批次时,分子B就会在原子列表进行填充3个mask原子到23个原子。那么分子B的node_mask的最后三行为0,其余为1。而原子A没有填充mask原子,因此分子A的node_mask全都为1。将分子A和B的node_mask按照行contact起来,就获得批次的node_mask。

同样,edge_maskt也是类似的。由于在SE3等变网络中,图可能是动态更新的,所以需要edge_mask。这本文的EGNN汇总,图不是动态更新的,因此,为None。

以下是一个常用在分子生成领域的SE3等变图网络层 EGNN代码解析。

二、等变网络EGNN层

在__init__定义函数中,

(1) 包含了h的输入,隐藏,输出维度的转化,包括:in_node_nf, hidden_nf, out_node_nf,对应着self.embedding初始嵌入层和self.embedding_out最后的输出层,两个线性转换层(MLP)。

(2) 定义了等变神经网络,由多个等变模块EquivariantBlock构成。

for i in range(0, n_layers):
  self.add_module("e_block_%d" % i, EquivariantBlock(
    hidden_nf, edge_feat_nf=edge_feat_nf, device=device,
    act_fn=act_fn, n_layers=inv_sublayers,
    attention=attention, norm_diff=norm_diff, tanh=tanh,
    coords_range=coords_range, norm_constant=norm_constant,
    sin_embedding=self.sin_embedding,
    normalization_factor=self.normalization_factor,
    aggregation_method=self.aggregation_method))

注意:norm_diff参数似乎在后面并没有用到。

在forward函数中,

(1) coord2diff 计算组成边两个节点的径向距离。如果self.sin_embedding为True,使用正弦余弦进行嵌入;

(2) self.embedding将h 嵌入到隐藏层维度

(3) 多个EquivariantBlock等变网络模块运算,其中,原子间的径向距离将被作用边的特征;

(4) self.embedding_out 将 h 输出层的维度

EGNN网络层代码如下:

class EGNN(nn.Module):
    def __init__(self, in_node_nf, in_edge_nf, hidden_nf, device='cpu', act_fn=nn.SiLU(), n_layers=3, attention=False,
                 norm_diff=True, out_node_nf=None, tanh=False, coords_range=15, norm_constant=1, inv_sublayers=2,
                 sin_embedding=False, normalization_factor=100, aggregation_method='sum'):
        super(EGNN, self).__init__()
        # 输出 h 的维度,默认与输入相同
        if out_node_nf is None:
            out_node_nf = in_node_nf
        # h 隐藏层的维度
        self.hidden_nf = hidden_nf
        # 设备,cpu/gpu
        self.device = device
        # 等变模块的层数
        self.n_layers = n_layers
        # 距离tanh后的放大倍数,注:tanh激活函数 压缩到(-1.1)
        self.coords_range_layer = float(coords_range/n_layers) if n_layers > 0 else float(coords_range)
        # ???这个参数在后面并没有用到。
        self.norm_diff = norm_diff
        # 归一化因子
        self.normalization_factor = normalization_factor
        # 聚合方法
        self.aggregation_method = aggregation_method

        # 距离的正弦余弦嵌入
        if sin_embedding:
            self.sin_embedding = SinusoidsEmbeddingNew()
            edge_feat_nf = self.sin_embedding.dim * 2
        else:
            self.sin_embedding = None
            edge_feat_nf = 2

        # h,初始嵌入层
        self.embedding = nn.Linear(in_node_nf, self.hidden_nf)
        # h, 输出嵌入层
        self.embedding_out = nn.Linear(self.hidden_nf, out_node_nf)
        # 多个等变网络模块EquivariantBlock组成等变网络
        for i in range(0, n_layers):
            self.add_module("e_block_%d" % i, EquivariantBlock(hidden_nf, edge_feat_nf=edge_feat_nf, device=device,
                                                               act_fn=act_fn, n_layers=inv_sublayers,
                                                               attention=attention, norm_diff=norm_diff, tanh=tanh,
                                                               coords_range=coords_range, norm_constant=norm_constant,
                                                               sin_embedding=self.sin_embedding,
                                                               normalization_factor=self.normalization_factor,
                                                               aggregation_method=self.aggregation_method))
        self.to(self.device)

    def forward(self, h, x, edge_index, node_mask=None, edge_mask=None):
        # Edit Emiel: Remove velocity as input
        # 计算边的两个节点的坐标差异和距离,返回径向距离和归一化的坐标差
        distances, _ = coord2diff(x, edge_index)
        # 对节点径向距离使用正弦余弦嵌入
        if self.sin_embedding is not None:
            distances = self.sin_embedding(distances)
        # h 初始嵌入
        h = self.embedding(h)
        # 等变模块
        for i in range(0, self.n_layers):
            h, x = self._modules["e_block_%d" % i](
                h, x, edge_index, node_mask=node_mask, 
                edge_mask=edge_mask, edge_attr=distances
                )
            # 计划在这里增加BN层,以增加模型训练稳定性

        # Important, the bias of the last linear might be non-zero
        # h 输出
        h = self.embedding_out(h)
        # mask节点 h 置零
        if node_mask is not None:
            h = h * node_mask
        return h, x

三、计算原子间距离-coord2diff

coord2diff函数计算边的两个节点的距离,以及归一化后的坐标差。

def coord2diff(x, edge_index, norm_constant=1):
    row, col = edge_index
    coord_diff = x[row] - x[col]
    radial = torch.sum((coord_diff) ** 2, 1).unsqueeze(1)
    norm = torch.sqrt(radial + 1e-8)
    coord_diff = coord_diff/(norm + norm_constant)
    return radial, coord_diff

四、正/余弦映射-SinusoidsEmbeddingNew

SinusoidsEmbeddingNew 在这里是将距离,通过正/余弦映射到高维度。

基于正弦和余弦的嵌入转换通过将输入数据映射到一个高维的频率空间中,增强了模型的特征表示能力和非线性表达能力,使得模型能够更好地理解和处理复杂的输入数据。

在forward函数中,首先会计算频率数, 按照默认值计算出来的结果是5。

然后,距离x乘以频率[None, :],将对距离末尾新增一维。例如,距离x的维度原来是(batch_size, seq_length),因为,频率的维度self.frequencies是5,因此,x * self.frequencies[None, :]的到转化后的距离是(batch_size, seq_length, 5)。

注意,SinusoidsEmbeddingNew并没有可训练的参数,只是将输入的特征进行了嵌入高维的然后缩放。

SinusoidsEmbeddingNew类的代码如下:

class SinusoidsEmbeddingNew(nn.Module):
    def __init__(self, max_res=15., min_res=15. / 2000., div_factor=4):
        super().__init__()
        # 计算频率数
        self.n_frequencies = int(math.log(max_res / min_res, div_factor)) + 1
        # 频率
        self.frequencies = 2 * math.pi * div_factor ** torch.arange(self.n_frequencies)/max_res
        self.dim = len(self.frequencies) * 2

    def forward(self, x):
        x = torch.sqrt(x + 1e-8)
        emb = x * self.frequencies[None, :].to(x.device)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb.detach()

五、等变网络模块EquivariantBlock

等变网络模块由多个EquivariantBlock串联组成。在EGNN中,每个EquivariantBlock的输入的是:节点类型特征h, 节点坐标x,边 egde_index,节点掩码node_mask,边掩码 edge_mask,边特征 edge_attr;输出的是:更新后的节点类型特征h, 节点坐标x

以下是等变网络基础模块EquivariantBlock的代码。

class EquivariantBlock(nn.Module):
    def __init__(self, hidden_nf, edge_feat_nf=2, device='cpu', act_fn=nn.SiLU(), n_layers=2, attention=True,
                 norm_diff=True, tanh=False, coords_range=15, norm_constant=1, sin_embedding=None,
                 normalization_factor=100, aggregation_method='sum'):
        super(EquivariantBlock, self).__init__()
        self.hidden_nf = hidden_nf # h 隐藏层维度/输入维度
        self.device = device # GPU or CPU
        self.n_layers = n_layers # 更新节点特征h 卷积层数
        self.coords_range_layer = float(coords_range) # 坐标tanh后缩放
        self.norm_diff = norm_diff # 应该是对坐标进行归一化,但在这部分代码中并未使用
        self.norm_constant = norm_constant # 坐标缩放
        self.sin_embedding = sin_embedding # 距离正弦余弦嵌入
        self.normalization_factor = normalization_factor # 归一化因子
        self.aggregation_method = aggregation_method # 消息聚合方式

        for i in range(0, n_layers):
            # h的消息传递层,不变操作
            self.add_module("gcl_%d" % i, GCL(self.hidden_nf, self.hidden_nf, self.hidden_nf, edges_in_d=edge_feat_nf,
                                              act_fn=act_fn, attention=attention,
                                              normalization_factor=self.normalization_factor,
                                              aggregation_method=self.aggregation_method))
        # 图结构更新,等变操作,更新坐标
        self.add_module("gcl_equiv", EquivariantUpdate(hidden_nf, edges_in_d=edge_feat_nf, act_fn=nn.SiLU(), tanh=tanh,
                                                       coords_range=self.coords_range_layer,
                                                       normalization_factor=self.normalization_factor,
                                                       aggregation_method=self.aggregation_method))
        self.to(self.device)

    def forward(self, h, x, edge_index, node_mask=None, edge_mask=None, edge_attr=None):
        # Edit Emiel: Remove velocity as input
        # 计算边的两节点的距离
        distances, coord_diff = coord2diff(x, edge_index, self.norm_constant)
        # 距离进行正弦/余弦嵌入
        if self.sin_embedding is not None:
            distances = self.sin_embedding(distances)
        # 距离作为边的特征,contact其他边特征
        edge_attr = torch.cat([distances, edge_attr], dim=1)
        # 节点特征h的消息传递 (不变操作)
        for i in range(0, self.n_layers):
            h, _ = self._modules["gcl_%d" % i](h, edge_index, edge_attr=edge_attr, node_mask=node_mask, edge_mask=edge_mask)
        # 更新坐标,会导致更新图,(等变操作)
        x = self._modules["gcl_equiv"](
            h, x, edge_index, coord_diff, edge_attr, node_mask, edge_mask)

        # Important, the bias of the last linear might be non-zero
        # 节点掩码,将不是没有mask原子的信息改为0
        if node_mask is not None:
            h = h * node_mask
        return h, x

在__init__函数中,分别定义了不变操作的GCL层和等变操作的EquivariantUpdate层。以及一些超参数。

属于不变操作的GCL层的超参数有:h特征的输入维度hidden_nf,卷积层数n_layers。

属于的等变操作EquivariantUpdate的超参数:坐标在tanh以后的缩放系数coords_range,是否对坐标进行标准化norm_diff(实际未使用)。

以及其他一些超参数:计算距离时的缩放参数norm_constant,距离是否进行正余弦嵌入sin_embedding。以及在等变和不变操作中都需要的超参数,消息聚合方式aggregation_method, 归一化因子normalization_factor。

在forward函数中,首先根据输入的节点坐标x和edge_index计算及边的距离,并根据self.sin_embedding确定是否进行正弦余弦嵌入,嵌入函数为之前介绍过的SinusoidsEmbeddingNew。

然后,将边的距离与输入的边的特征(默认为None),合并为新的边特征。

随后,通过多层的GCL层(即,"gcl_%d" % i),更新节点特征h。

然后,通过等变操作"gcl_equiv"使用x和新的h,更新原子坐标x。最后,对mask原子的h特征进行置零,避免影响下一步的操作。

六、不变操作的GCL

不变操作的GCL层,与常规的图神经卷积网络相类似,在特定的图结构中,对节点特征和边特征进行消息传递,更新。

不变操作的GCL层代码如下:

class GCL(nn.Module):
    def __init__(self, input_nf, output_nf, hidden_nf, normalization_factor, aggregation_method,
                 edges_in_d=0, nodes_att_dim=0, act_fn=nn.SiLU(), attention=False):
        super(GCL, self).__init__()
        input_edge = input_nf * 2
        self.normalization_factor = normalization_factor # 归一化因子
        self.aggregation_method = aggregation_method # 聚合方法
        self.attention = attention # 是否使用注意力机制
        
        # 边 MLP层
        self.edge_mlp = nn.Sequential(
            nn.Linear(input_edge + edges_in_d, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, hidden_nf),
            act_fn)

        # 节点 MLP 层
        self.node_mlp = nn.Sequential(
            nn.Linear(hidden_nf + input_nf + nodes_att_dim, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, output_nf))

        # 自注意层
        if self.attention:
            self.att_mlp = nn.Sequential(
                nn.Linear(hidden_nf, 1),
                nn.Sigmoid())

    def edge_model(self, source, target, edge_attr, edge_mask):
        '''
        更新边特征
        source: 起始节点特征
        target:终止节点特征
        edge_attr:边特征
        edge_mask:边掩码
        '''
        # 边特征
        if edge_attr is None:  # Unused.
            out = torch.cat([source, target], dim=1)
        else:
            out = torch.cat([source, target, edge_attr], dim=1)
        # 更新后边的特征
        mij = self.edge_mlp(out)

        # 自注意更新边特征
        if self.attention:
            att_val = self.att_mlp(mij)
            out = mij * att_val
        else:
            out = mij
        # 边掩码处理
        if edge_mask is not None:
            out = out * edge_mask
        # 输出更新后的边特征out,注意力之前边特征mij
        return out, mij

    def node_model(self, x, edge_index, edge_attr, node_attr):
        '''
        x:节点特征,注意不是坐标
        edge_index:边索引
        edge_attr:边特征
        node_attr:额外的节点特征,默认为None
        '''
        row, col = edge_index # 起始节点编号,终止节点编号
        # 自定义聚合操作,将边信息,聚合到节点中,并根据 aggregation_method 对结果进行归一化处理
        agg = unsorted_segment_sum(edge_attr, row, num_segments=x.size(0),
                                   normalization_factor=self.normalization_factor,
                                   aggregation_method=self.aggregation_method)
        # 合并 节点特征,聚合的边信息,额外的节点特征
        if node_attr is not None:
            agg = torch.cat([x, agg, node_attr], dim=1)
        else:
            agg = torch.cat([x, agg], dim=1)
        # 更新节点特征
        out = x + self.node_mlp(agg)
        # 返回 更新后的节点特征、聚合的边信息
        return out, agg

    def forward(self, h, edge_index, edge_attr=None, node_attr=None, node_mask=None, edge_mask=None):
        row, col = edge_index # 边的起始节点和终止节点
        edge_feat, mij = self.edge_model(h[row], h[col], edge_attr, edge_mask) # 边特征更新,获得边信息edge_feat
        h, agg = self.node_model(h, edge_index, edge_feat, node_attr) # 更新节点特征h, 使用节点特征h, 边信息edge_feat,额外的节点特征
        # mask原子的处理
        if node_mask is not None:
            h = h * node_mask
        return h, mij

在__init__函数中,定义了三个超参数:输入的边特征的维度input_edge(默认是节点特征维度的2倍),归一化因子normalization_factor,是否使用自注意力机制attention。

然后定义了 边信息更新的MLP层,节点信息更新的self.edge_mlp层,以及自注意力self.attention层。

在edge_mode函数中,将起边的起始节点和终止节点,以及边特征,合并为新的边的节点。

然后,使用self.edge_mlp进行更新边的信息。随后进行掩码处理。

输出,更新后边特征out,以及注意力机制前的边特征mij。

在 node_model 中,利用 unsorted_segment_sum将边的信息聚合到起始节点上,然后将聚合后的边信息,节点信息,以及额外的节点信息合并,通过self.node_mlp进行更新节点特征,在通过残差链接。

输出:更新后的节点特征out,以及更新前的节点特征agg。

在forward函数中,完成整个图神经网络的节点信息更新。

首先,调用self.edge_model更新边特征,得到新的边特征edge_feat。self.node_model聚合边信息到节点上,并更新节点特征h。然后进行节点特征的mask处理。

输出:更新后的节点特征h(不变量)和注意力机制前的边信息mij。

七、等变操作EquivariantUpdate层

等变操作EquivariantUpdate层完成坐标部分的更新。其代码如下:

class EquivariantUpdate(nn.Module):
    def __init__(self, hidden_nf, normalization_factor, aggregation_method,
                 edges_in_d=1, act_fn=nn.SiLU(), tanh=False, coords_range=10.0):
        super(EquivariantUpdate, self).__init__()
        self.tanh = tanh # 是否对坐标进行tanh处理
        self.coords_range = coords_range # 坐标范围, 边特征计算出来的距离tanh以后的缩放系数
        input_edge = hidden_nf * 2 + edges_in_d # 边的输入维度
        layer = nn.Linear(hidden_nf, 1, bias=False) # 边特征更新输出层,维度为 1 (距离)
        torch.nn.init.xavier_uniform_(layer.weight, gain=0.001) # 初始化,pytorch默认是Kaiming初始化,为0。
        # 更新边的特征,映射到距离(边特征计算出来的距离)
        self.coord_mlp = nn.Sequential(
            nn.Linear(input_edge, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, hidden_nf),
            act_fn,
            layer)
        self.normalization_factor = normalization_factor # 归一化因子
        self.aggregation_method = aggregation_method  # 聚合方式

    def coord_model(self, h, coord, edge_index, coord_diff, edge_attr, edge_mask):
        # 坐标更新函数
        row, col = edge_index # 起始节点编号,终止节点编号
        # 合并边特征(起始节点特征、终止节点特征、边特征)
        input_tensor = torch.cat([h[row], h[col], edge_attr], dim=1)
        # 边特征映射到距离,使用tanh后进行缩放
        if self.tanh:
            trans = coord_diff * torch.tanh(self.coord_mlp(input_tensor)) * self.coords_range
        else:
            trans = coord_diff * self.coord_mlp(input_tensor)
        # 边 进行mask处理
        if edge_mask is not None:
            trans = trans * edge_mask
        # 聚合边特征到起始节点上
        agg = unsorted_segment_sum(trans, row, num_segments=coord.size(0),
                                   normalization_factor=self.normalization_factor,
                                   aggregation_method=self.aggregation_method)
        # 更新节点特征(坐标),残差更新
        coord = coord + agg
        # 输出更新以后的坐标
        return coord

    def forward(self, h, coord, edge_index, coord_diff, edge_attr=None, node_mask=None, edge_mask=None):
        # 坐标更新
        coord = self.coord_model(h, coord, edge_index, coord_diff, edge_attr, edge_mask)
        # mask原子,mask 处理
        if node_mask is not None:
            coord = coord * node_mask
        return coord

在__init__函数中, 主要是定义了一个self.coord_mlp层机器超参数。

其中,coords_range为边特征距离在tanh以后的缩放系数;input_edge为输入self.coord_mlp层的维度,包括,起始节点h,终止节点h,以及边距离,该层输出坐标“变化”;layer为self.coord_mlp的输出层,输出将添加在坐标x上的“变化”,因为layer最后输出的维度是1,因此,xyz上的“变化”相同。

在coord_model函数中,首先将起始节点h,终止节点h,以及边的特征(距离)进行特征拼接。

(注:特征拼接不影响SE3等变特性。)随后,利用多层感知机self.coord_mlp对拼接的特征进行线性变换,然后经过tanh处理后乘以节点特征距离coord_diff以及缩放系数coords_range。

在mask边处理以后,通过unsorted_segment_sum将多层感知机线性变换输出的结果(“变化”),聚合到边的起始节点上。

coord_model函数返回,添加“变化”以后的坐标x(即更新后的坐标)。(注:unsorted_segment_sum为聚合信息函数,不影响等变特性和不变特性)。

在forward函数中,则是直接调用coord_model函数更新坐标x,随后进行mask处理。

关于 EquivariantUpdate 的SE3等变特性,EquivariantUpdate 中关键函数是coord_model函数。

coord_model函数先进行特征拼接,然后对特征拼接进行线性变换,最后进行聚合处理。

正如之前所说的,特征的拼接和聚合不影响SE3等变特性,对于节点特征标量h和原子坐标向量x都可以使用。

但是,多层感知机self.coord_mlp不是等变的,不能直接处理坐标x,因此,在EGNN中,self.coord_mlp并没有对坐标向量x进行线性变换,输入不是坐标x,而是节点特征,距离,计算坐标出来“变化”。

因为输入的是标量,输出的“变化”类似于节点特征、距离、角度都是标量,不是向量,添加在坐标向量x上,没有破坏坐标向量x的SE3等变性质。

(注:距离和角度在旋转条件下保持不变,因此属于不变量)总结来说,coord_model函数是SE3等变的,EquivariantUpdate 也是SE3等变的。

关于self.coord_mlp中的等变,self.coord_mlp的组成顺序是,线性变换层nn.Linear、激活函数act_fn(Silu)、线性变换层nn.Linear、激活函数act_fn(Silu)、线性变换层nn.Linear组成。

其中,线性变换层nn.Linear是不满足SE3等变的,因此,self.coord_mlp不满足SE3等变。注:激活函数Silu,不破坏等变性质。

八、边信息聚合到节点函数 unsorted_segment_sum

再GCL层和EquivariantUpdate层都有使用到的边信息聚合到节点函数 unsorted_segment_sum的代码如下。

def unsorted_segment_sum(data, segment_ids, num_segments, normalization_factor, aggregation_method: str):
    """
        模拟 TensorFlow 中的 unsorted_segment_sum 操作。
        它根据提供的segment_ids 对输入数据进行聚合。
        支持 mean 和 sum
    """
    # data 维度是 (n,m) n为节点数,m为特征数
    # segment_ids 为节点ID。(n, )
    # num_segments 样本总数,一般为n
    # normalization_factor 归一化因子
    # aggregation_method 聚合方法
    result_shape = (num_segments, data.size(1)) # 结果维度 (n, m)
    result = data.new_full(result_shape, 0)  # Init empty result tensor. 初始化,填充值0
    segment_ids = segment_ids.unsqueeze(-1).expand(-1, data.size(1)) # segment_ids节点ID扩展到(n,m)
    # scatter_add_的用法,根据segment_ids 将 data 中的值加到 result 张量的相应位置
    result.scatter_add_(0, segment_ids, data)
    # 聚合
    if aggregation_method == 'sum':
        result = result / normalization_factor

    if aggregation_method == 'mean':
        norm = data.new_zeros(result.shape)
        norm.scatter_add_(0, segment_ids, data.new_ones(data.shape))
        norm[norm == 0] = 1  # 将 norm 中为 0 的位置设置为 1,以避免除以 0
        result = result / norm
    # 输出的result的维度为(n,m),其中n来自于num_segments的维度。
    return result

输入的data是边的信息,segment_ids为边信息聚合到的节点ID,num_segments是总的节点数量。

根据num_segments和data的维度,创建输出result的维度。

然后,根据聚合节点ID segment_ids,使用scatter_add_方法直接进行填充。

在进行归一化处理以后,沿着维度dim=1进行平均或者加和处理。

在平均时,为了避免除以0,对全部为0的节点(mask节点)变为1。

九、nn.linear和激活函数的SE3等变

SE3等变要求在旋转和平移下,函数能保持一致。具体来说,输入特征x,在平移和旋转之后变为 R(x)+t,函数f(x)的输出从f(R(x)+t),可以等价为:R(f(x))+t。

一般而言,激活函数,Relu, tanh,Silu等都是逐点操作(即,对每一个数值进行操作),不影响输入特征的等变性质。

对于nn.Linear层( W(x) ),由于初始化的时候,权重矩阵W是随机的,并不能保证 R(W(x)) = W(R(x)) ,因此,nn.Linear不是SE3等变的,不能变换原子坐标x等等变向量。

总结

1. 等变网络EGNN层由多个不变量h更新的GCL层,和一个等变向量(原子坐标x)更新的EquivariantUpdate层组成。

2. GCL层与常规的图神经网络类似,在边和节点上进行消息传递,聚合,更新h。

3. EquivariantUpdate层将起始节点、终止节点特征h,以及边特征(距离)进行线性变换(线性变换层)以后,得到添加在等变向量原子坐标x上的变化。注意,在EquivariantUpdate层并不能保持SE3等变特性,不是直接对坐标进行变换,而是对距离进行变换。利用线性变换层输出的坐标“变化”更新坐标向量x。

4. 激活函数一般而言都是对元素操作,满足SE3等变,但是,nn.Linear不满足SE3等变操作,不能直接对坐标向量x进行变换。

  • 10
    点赞
  • 23
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值