计算不同氨基酸的邻居
通过计算氨基酸中Ca原子之间的欧式距离,选取距离最小的前30个氨基酸作为邻居,得到邻居的索引E_idx以及对应的距离D_neighbors和领居掩码mask_neighbors。
#X:[6,217,3],批次大小为6,序列最大长度为217,3表示氨基酸原子Ca的坐标
#mask:[6,217],因为蛋白质序列长度不一,对其进行填充0或者nan,mask中填充氨基酸为0,原有为1
def _dist(X,mask,eps=1E-6):
"""Pairwise euclidean distances"""
mask_2D = torch.unsqueeze(mask,1) * torch.unsqueeze(mask,2)
#[6,1,217]*[6,217,1]=[6,217,217]
#对应元素相乘,找出有效的元素对,有效氨基酸与nan氨基酸对表示为0
dX = torch.unsqueeze(X,1) - torch.unsqueeze(X,2)#[6,217,217,3],计算距离
D = (1. - mask_2D)*10000 + mask_2D* torch.sqrt(torch.sum(dX**2, 3) + eps)
"""
(1. - mask_2D) * 10000:对无效位置(即 mask_2D 中为0的位置)进行处理,将其乘以一个大的数值(这里是10000),
使得这些位置的距离在后续计算中不会被考虑。
对有效位置进行处理,计算两位置之间的欧氏距离。torch.sum(dX**2, 3) 对每对坐标点计算了它们之间的平方欧氏距离,
torch.sqrt() 对这些平方欧氏距离取平方根。eps 是一个小的常数,用于避免在计算平方根时出现除零错误。
"""
#[6,217,1]
D_max, _ = torch.max(D, -1, keepdim=True)#对 D 沿着最后一个维度计算最大值。D_max 存储了每个位置大值
D_adjust = D + (1. - mask_2D) * (D_max+1)
#D_ne:[6,217,30],E_idx[6,217,30]
D_neighbors, E_idx = torch.topk(D_adjust, min(self.top_k, D_adjust.shape[-1]), dim=-1, largest=False)
"""
对D_adjust沿着最后一个维度计算最小的前k个元素。D_neighbors存储了每个位置的最近邻距离,
而E_idx存储了这些最近邻位置的索引
"""
mask_neighbors = gather_edges(mask_2D.unsqueeze(-1), E_idx)
"""
根据邻居索引收集mask_2D的特征,根据function要求,对mask_2D增加维度[6,217,217]->[6,217,217,1]
E_idx[6,217,30],得到mask_neighbors[6,217,30,1],这样mask_nei只保留邻居的mask,是否为0
"""
return D_neighbors, E_idx, mask_neighbors
def gather_edges(edges, neighbor_idx):
# Features [B,N,N,C] at Neighbor indices [B,N,K] => Neighbor features [B,N,K,C]
neighbors = neighbor_idx.unsqueeze(-1).expand(-1, -1, -1, edges.size(-1))
#将neighbor_idx转化为edges相同维度
edge_features = torch.gather(edges, 2, neighbors)
#在 edges 张量的第二个维度上进行收集,使用 neighbors 张量作为索引
return edge_features
计算蛋白质中相对空间编码
这里也是以Ca原子坐标来计算,
灵活的主链特征
包括键角的计算和虚拟二面角
虚拟键角和二面角的计算,主要是根据氨基酸中Ca的坐标
"""
输入X:[6,217,3] Ca原子
"""
def virtual_angles_bond(X,eps=1e-6):
#dX:[6,216,3]
dX = X[:,1:,:]-X[:,:-1,:]#Ca1-Ca0,Ca2-Ca1...Ca{n}-Ca{n-1},相邻ca原子坐标相减
U = F.normalize(dX, dim=-1) #对 dX 进行标准化,即将其转换为单位向量
u_2 = U[:,:-2,:]#[6,214,3]
u_1 = U[:,1:-1,:]
u_0 = U[:,2:,:]
"""
将Ca{n}-Ca{n-1}定义为d{n.n-1},u_*就是提取不同位置的单位向量
"""
n_2 = F.normalize(torch.cross(u_2, u_1), dim=-1)#[6,214,3]
#计算d{n-1.n-2}与d{n.n-1}之间的法向量并正则化
n_1 = F.normalize(torch.cross(u_1, u_0), dim=-1)
"""
计算虚拟键角
"""
cosA = -(u_1 * u_0).sum(-1)#[6,214,3]
#根据计算d{n-1.n-2}与d{n.n-1}夹角的cos值
cosA = torch.clamp(cosA, -1+eps, 1-eps)#限制cos值的范围
A = torch.acos(cosA)#计算cos的角
"""
计算虚拟二面角
"""
cosD = (n_2 * n_1).sum(-1)
cosD = torch.clamp(cosD, -1+eps, 1-eps)
D = torch.sign((u_2 * n_1).sum(-1)) * torch.acos(cosD)#[6,214,3]
#得到主链特征
AD_features = torch.stack((torch.cos(A), torch.sin(A) * torch.cos(D),
torch.sin(A) * torch.sin(D)), 2)#[6,214,3]
#它们解释为单位圆坐标,并将它们表示为单位球面上的点
AD_features = F.pad(AD_features, (0,0,1,2), 'constant', 0)#第二个维度上下各填充一行0
return A,D,AD_feature