一、TagMol 背景介绍
TAGMoL 是一个基于分子属性条件引导扩散的 3D 分子生成模型,适合在给定靶标蛋白质的情况下,可以生成一系列满足目标特性(分子属性,binding affinity)的候选分子。
TAGMoL 来源于新德里 Molecule AI, 以及美国马萨诸塞大学曼宁信息与计算机科学学院 的Vineeth Dorna 和 D. Subhalingam 为通讯作者的文章:《TAGMoL : Target Aware Gradient-guided Molecule Generation》。文章链接:https://arxiv.org/abs/2406.01650 。该文章于 2024 年 6 月 3 日发表在 arxiv 上。
该模型已经经过了评测,详见 《分子属性梯度引导的3D分子生成扩散模型 TAGMOL - 评测》
分子属性梯度引导的3D分子生成扩散模型 TAGMOL - 评测-CSDN博客
TagMol 模型分为两部分,分子生成模型和梯度引导模型。其中,在原文中,分子生成模型使用的是TargetDiff 模型无需训练,拿原作者提供的 checkpoint 直接使用即可,梯度引导模型为自己训练。在分子生成过程中,使用梯度引导模型,修改 TargetDiff 的生成模型过程,实现特定属性的分子生成。因此,在代码解析部分,我们关心分子生成过程,以及如何训练一个属性引导模型,包含训练梯度引导模型需要的数据预处理过程。(TargetDiff 训练过程则跳过,详见 TargetDiff 部分)。
以下为梯度引导模型及其训练代码解析。
注: 因为原 GitHub 的代码存在 bugs,这里解析的是我们之前修改过的代码,与作者在 GitHub 上提供的稍微不一致。代码的运行环境,更多评测内容,也请看之前的文章《分子属性梯度引导的3D分子生成扩散模型 TAGMOL - 评测》。
二、TagMol 的分子生成代码解析
以多目标分子生成为例,其运行命令为:
python scripts/sample_for_pocket_guided.py \
configs/noise_guide_multi/sampling_guided_qed_0.33_sa_0.33_ba_0.34.yml \
--pdb_path ./3wze/pocket_3wze.pdb \
--result_path outputs_TagMol_3wze_path
在上述命令中,调用 configs/noise_guide_multi/sampling_guided_qed_0.33_sa_0.33_ba_0.34.yml 配置文件,输入的蛋白是 ./3wze/pocket_3wze.pdb, 分子生成结果保存至 outputs_TagMol_3wze_path。
配置文件内容如下,其中,设置了基础模型的 checkpoint,多个引导模型的 checkpoint,及其权重、梯度、类型设置,最后是设置了分子生成采样(采样数量,去噪步数,质量中心等):
# 基础模型
model:
checkpoint: ./pretrained_models/pretrained_diffusion.pt
# 引导模型
guide_models:
- name: qed
checkpoint: ./logs/training_dock_guide_qed_2024_01_06__01_35_21/checkpoints/186000.pt
weight: 0.33
guide_kind: Kd
gradient_scale_cord: 20
gradient_scale_categ: 0.0 #1e-10
clamp_pred_max: 1.0
- name: sa
checkpoint: ./logs/training_dock_guide_sa_2024_01_20__15_38_49/checkpoints/162000.pt
weight: 0.33
guide_kind: Kd
gradient_scale_cord: 5
gradient_scale_categ: 0.0 #1e-10
clamp_pred_max: 1.0
- name: binding_affinity
checkpoint: ./logs/training_dock_guide_2023_12_17__06_23_35/checkpoints/184000.pt
weight: 0.34
guide_kind: Kd
gradient_scale_cord: 2.0
gradient_scale_categ: 0.0 #1e-10
# 采样设置
sample:
seed: 2021
num_samples: 100
num_steps: 1000
pos_only: False
center_pos_mode: protein
sample_num_atoms: prior
因此,我们首先看看 scripts/sample_for_pocket_guided.py 文件的代码。
2.1 __main__ 函数 sample_for_pocket_guided()
首先是 __init__.py 函数。
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('config', type=str)
parser.add_argument('--pdb_path', type=str) # 口袋/蛋白路径
parser.add_argument('--device', type=str, default='cuda:0') # GPU
parser.add_argument('--batch_size', type=int, default=50) # 批次大小
parser.add_argument('--result_path', type=str, default='./outputs_guided_pdb') # 结果保存路径
parser.add_argument('--num_samples', type=int) # 采样分子数
args = parser.parse_args()
# 分子生成
sample_for_pocket_guided(config_path=args.config,
pdb_path=args.pdb_path,
device=args.device,
batch_size=args.batch_size,
result_path=args.result_path)
其中,指定了配置文件,口袋/蛋白文件,输出路径,生成分子数。随后,调用 sample_for_pocket_guided 函数进行分子生成。
sample_for_pocket_guided 函数的代码如下:
def sample_for_pocket_guided(config_path, pdb_path, device, batch_size, result_path):
logger = misc.get_logger('evaluate')
# Load config 加载配置文件
config = misc.load_config(config_path)
logger.info(config)
misc.seed_all(config.sample.seed)
# Load checkpoint 加载主模型的 checkpoint
ckpt = torch.load(config.model.checkpoint, map_location=device)
logger.info(f"Training Config: {ckpt['config']}")
# Transforms 主模型数据转换器
protein_featurizer = trans.FeaturizeProteinAtom()
ligand_atom_mode = ckpt['config'].data.transform.ligand_atom_mode
ligand_featurizer = trans.FeaturizeLigandAtom(ligand_atom_mode)
transform = Compose([
protein_featurizer,
])
# Load model 加载主模型
model = ScorePosNet3D(
ckpt['config'].model,
protein_atom_feature_dim=protein_featurizer.feature_dim,
ligand_atom_feature_dim=ligand_featurizer.feature_dim
).to(device)
# 加载模型权重,strict=False,这意味着模型结构和权重不完全匹配时,不会抛出错误,允许跳过一些不匹配的参数。
model.load_state_dict(ckpt['model'], strict=False if 'train_config' in config.model else True)
model.eval()
for param in model.parameters():
param.requires_grad = False
logger.info(f'Successfully load the model! {config.model.checkpoint}')
# Load Guide Checkpoint
# guide_ckpt = torch.load(config.guide_model.checkpoint, map_location=device)
# guide_ckpt = torch.load(config.model.checkpoint, map_location=device) # wufeil
# logger.info(f"Guide Training Config: {guide_ckpt['config']}")
# Guide Transforms 引导模型的数据转换器
guide_protein_featurizer = GuideFeaturizeProteinAtom()
guide_ligand_featurizer = GuideFeaturizeLigandAtom(mode=ckpt['config']['data']['transform']['ligand_atom_mode'])
# Guide model
# guide_model = get_guide_model(guide_ckpt['config'], guide_protein_featurizer.feature_dim, guide_ligand_featurizer.feature_dim)
# guide_model.load_state_dict(guide_ckpt['model'])
# guide_model = guide_model.to(device)
# guide_model.eval()
# for param in guide_model.parameters():
# param.requires_grad = False
# Guide models 多个引导模型加载到 guide_models 中
### wufeil #######################################################################
guide_models = []
for guide_model_config in config.guide_models:
# Load Guide Checkpoint
guide_ckpt = torch.load(guide_model_config.checkpoint, map_location=device)
print(f"\nGuide Name: {guide_model_config.name}")
logger.info(f"Guide Name: {guide_model_config.name}")
print(f"\nGuide Training Config: {guide_ckpt['config']}")
logger.info(f"Guide Training Config: {guide_ckpt['config']}")
# Guide model
guide_model = get_guide_model(guide_ckpt['config'], guide_protein_featurizer.feature_dim, guide_ligand_featurizer.feature_dim)
guide_model.load_state_dict(guide_ckpt['model'])
guide_model = guide_model.to(device)
guide_model.eval()
guide_models.append(guide_model)
#####################################################################
# Load pocket 加载口袋 pdb 文件,并转换为模型输入
data = pdb_to_pocket_data(pdb_path)
data = transform(data)
# 分子生成,获得生成分子的节点,各节点坐标(以及生成过程的轨迹)
all_pred_pos, all_pred_v, all_pred_pos_traj, all_pred_v_traj, all_pred_v0_traj, all_pred_vt_traj, all_pred_pos0_traj, time_list = sample_guided_diffusion_ligand(
model=model,
# guide_model=guide_model,
guide_models=guide_models,# wufeil
data=data,
guide_configs=config.guide_models, # wufeil
num_samples=config.sample.num_samples,
# kind=config.sample.guide_kind,
# gradient_scale_cord=config.sample.gradient_scale_cord,
# gradient_scale_categ=config.sample.gradient_scale_categ,
batch_size=batch_size,
device=device,
num_steps=config.sample.num_steps,
pos_only=config.sample.pos_only,
center_pos_mode=config.sample.center_pos_mode,
sample_num_atoms=config.sample.sample_num_atoms
)
result = {
'data': data,
'pred_ligand_pos': all_pred_pos,
'pred_ligand_v': all_pred_v,
'pred_ligand_pos_traj': all_pred_pos_traj,
'pred_ligand_v_traj': all_pred_v_traj,
'pred_ligand_v0_traj': all_pred_v0_traj,
'pred_ligand_vt_traj': all_pred_vt_traj,
'pred_pos0_traj': all_pred_pos0_traj
}
logger.info('Sample done!')
# reconstruction 重构分子,并统计重构成功分子数量
gen_mols = []
n_recon_success, n_complete = 0, 0
for sample_idx, (pred_pos, pred_v) in enumerate(zip(all_pred_pos, all_pred_v)):
pred_atom_type = trans.get_atomic_number_from_index(pred_v, mode='add_aromatic')
try:
pred_aromatic = trans.is_aromatic_from_index(pred_v, mode='add_aromatic')
mol = reconstruct.reconstruct_from_generated(pred_pos, pred_atom_type, pred_aromatic)
smiles = Chem.MolToSmiles(mol)
except reconstruct.MolReconsError:
gen_mols.append(None)
continue
n_recon_success += 1
if '.' in smiles:
gen_mols.append(None)
continue
n_complete += 1
gen_mols.append(mol)
result['mols'] = gen_mols
logger.info('Reconstruction done!')
logger.info(f'n recon: {n_recon_success} n complete: {n_complete}')
# 创建结果路径,并复制保存 配置文件及生成的小分子
os.makedirs(result_path, exist_ok=True)
shutil.copyfile(config_path, os.path.join(result_path, 'sample.yml'))
torch.save(result, os.path.join(result_path, f'sample.pt'))
mols_save_path = os.path.join(result_path, f'sdf')
os.makedirs(mols_save_path, exist_ok=True)
for idx, mol in enumerate(gen_mols):
if mol is not None:
sdf_writer = Chem.SDWriter(os.path.join(mols_save_path, f'{idx:03d}.sdf'))
sdf_writer.write(mol)
sdf_writer.close()
logger.info(f'Results are saved in {result_path}')
misc.close_logger(logger)
sample_for_pocket_guided 函数内容分为以下几步:
(1) 加载配置文件;
(2) 加载主模型(ScorePosNet3D,即 TargetDiff 模型) 数据转换器trans.FeaturizeProteinAtom() 及其model(即,TargetDiff);
(3) 加载引导模型的数据转换器 GuideFeaturizeProteinAtom 及多个属性引导模型 guide_models;
(4) pdb_to_pocket_data 函数加载口袋 pdb 文件得到 data;
(5) 调用 sample_guided_diffusion_ligand 函数传入 model 和 guide_models 以及生成配置,进行分子生成。
2.2 加载口袋 pdb 文件(pdb_to_pocket_data)
pdb_to_pocket_data 函数,传入一个口袋 pdb 文件,返回输入 TagMol 模型前的庶数据格式(字典)。将口袋 pdb 文件使用 PDBProtein 类中的 to_dict_atom 方法,按照原子进行提取,提取成为一个字典 pocket_dict,torchify_dict 转化为pytorch tensor。同时,设置一个空的小分子字典 ligand_dict,与 pocket_dict 通过 ProteinLigandData 类的 from_protein_ligand_dicts 方法组合成 data。代码如下:
def pdb_to_pocket_data(pdb_path):
pocket_dict = PDBProtein(pdb_path).to_dict_atom()
data = ProteinLigandData.from_protein_ligand_dicts(
protein_dict=torchify_dict(pocket_dict),
ligand_dict={
'element': torch.empty([0, ], dtype=torch.long),
'pos': torch.empty([0, 3], dtype=torch.float),
'atom_feature': torch.empty([0, 8], dtype=torch.float),
'bond_index': torch.empty([2, 0], dtype=torch.long),
'bond_type': torch.empty([0, ], dtype=torch.long),
}
)
return data
torchify_dict 内容比较简单,将字典中,值为 numpy 矩阵的对象转化为 pytorch tensor,如下:
def torchify_dict(data):
# 将 字典中,值为 numpy 矩阵的对象转化为 pytorch tensor
output = {}
for k, v in data.items():
if isinstance(v, np.ndarray):
output[k] = torch.from_numpy(v)
else:
output[k] = v
return output
2.2.1 解析 pdb 文件 PDBProtein()
PDBProtein 类逐行解析 pdb 文件内容(_parse 方法)。可以返回,原子级别的信息(包括:元素序号、名字、坐标、是否是主链、元素符号、所属的残基编码)、氨基酸级别信息(包括:氨基酸编码、氨基酸质心、CA, C, N, O 原子的坐标)。此外,还可以输入中心坐标和半径输出半径范围内的氨基酸。
详细代码如下:
class PDBProtein(object):
AA_NAME_SYM = {
'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F', 'GLY': 'G', 'HIS': 'H',
'ILE': 'I', 'LYS': 'K', 'LEU': 'L', 'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q',
'ARG': 'R', 'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y',
}
AA_NAME_NUMBER = {
k: i for i, (k, _) in enumerate(AA_NAME_SYM.items())
}
BACKBONE_NAMES = ["CA", "C", "N", "O"]
def __init__(self, data, mode='auto'):
super().__init__()
if (data[-4:].lower() == '.pdb' and mode == 'auto') or mode == 'path':
with open(data, 'r') as f:
self.block = f.read()
else:
self.block = data
self.ptable = Chem.GetPeriodicTable()
# Molecule properties
self.title = None
# Atom properties
self.atoms = []
self.element = []
self.atomic_weight = []
self.pos = []
self.atom_name = []
self.is_backbone = []
self.atom_to_aa_type = []
# Residue properties
self.residues = []
self.amino_acid = []
self.center_of_mass = []
self.pos_CA = []
self.pos_C = []
self.pos_N = []
self.pos_O = []
self._parse()
def _enum_formatted_atom_lines(self):
for line in self.block.splitlines():
if line[0:6].strip() == 'ATOM':
element_symb = line[76:78].strip().capitalize()
if len(element_symb) == 0:
element_symb = line[13:14]
yield {
'line': line,
'type': 'ATOM',
'atom_id': int(line[6:11]),
'atom_name': line[12:16].strip(),
'res_name': line[17:20].strip(),
'chain': line[21:22].strip(),
'res_id': int(line[22:26]),
'res_insert_id': line[26:27].strip(),
'x': float(line[30:38]),
'y': float(line[38:46]),
'z': float(line[46:54]),
'occupancy': float(line[54:60]),
'segment': line[72:76].strip(),
'element_symb': element_symb,
'charge': line[78:80].strip(),
}
elif line[0:6].strip() == 'HEADER':
yield {
'type': 'HEADER',
'value': line[10:].strip()
}
elif line[0:6].strip() == 'ENDMDL':
break # Some PDBs have more than 1 model.
def _parse(self):
# Process atoms
residues_tmp = {}
for atom in self._enum_formatted_atom_lines():
if atom['type'] == 'HEADER':
self.title = atom['value'].lower()
continue
self.atoms.append(atom)
atomic_number = self.ptable.GetAtomicNumber(atom['element_symb'])
next_ptr = len(self.element)
self.element.append(atomic_number)
self.atomic_weight.append(self.ptable.GetAtomicWeight(atomic_number))
self.pos.append(np.array([atom['x'], atom['y'], atom['z']], dtype=np.float32))
self.atom_name.append(atom['atom_name'])
self.is_backbone.append(atom['atom_name'] in self.BACKBONE_NAMES)
self.atom_to_aa_type.append(self.AA_NAME_NUMBER[atom['res_name']])
chain_res_id = '%s_%s_%d_%s' % (atom['chain'], atom['segment'], atom['res_id'], atom['res_insert_id'])
if chain_res_id not in residues_tmp:
residues_tmp[chain_res_id] = {
'name': atom['res_name'],
'atoms': [next_ptr],
'chain': atom['chain'],
'segment': atom['segment'],
}
else:
assert residues_tmp[chain_res_id]['name'] == atom['res_name']
assert residues_tmp[chain_res_id]['chain'] == atom['chain']
residues_tmp[chain_res_id]['atoms'].append(next_ptr)
# Process residues
self.residues = [r for _, r in residues_tmp.items()]
for residue in self.residues:
sum_pos = np.zeros([3], dtype=np.float32)
sum_mass = 0.0
for atom_idx in residue['atoms']:
sum_pos += self.pos[atom_idx] * self.atomic_weight[atom_idx]
sum_mass += self.atomic_weight[atom_idx]
if self.atom_name[atom_idx] in self.BACKBONE_NAMES:
residue['pos_%s' % self.atom_name[atom_idx]] = self.pos[atom_idx]
residue['center_of_mass'] = sum_pos / sum_mass
# Process backbone atoms of residues
for residue in self.residues:
self.amino_acid.append(self.AA_NAME_NUMBER[residue['name']])
self.center_of_mass.append(residue['center_of_mass'])
for name in self.BACKBONE_NAMES:
pos_key = 'pos_%s' % name # pos_CA, pos_C, pos_N, pos_O
if pos_key in residue:
getattr(self, pos_key).append(residue[pos_key])
else:
getattr(self, pos_key).append(residue['center_of_mass'])
def to_dict_atom(self):
return {
'element': np.array(self.element, dtype=np.long),
'molecule_name': self.title,
'pos': np.array(self.pos, dtype=np.float32),
'is_backbone': np.array(self.is_backbone, dtype=np.bool),
'atom_name': self.atom_name,
'atom_to_aa_type': np.array(self.atom_to_aa_type, dtype=np.long)
}
def to_dict_residue(self):
return {
'amino_acid': np.array(self.amino_acid, dtype=np.long),
'center_of_mass': np.array(self.center_of_mass, dtype=np.float32),
'pos_CA': np.array(self.pos_CA, dtype=np.float32),
'pos_C': np.array(self.pos_C, dtype=np.float32),
'pos_N': np.array(self.pos_N, dtype=np.float32),
'pos_O': np.array(self.pos_O, dtype=np.float32),
}
def query_residues_radius(self, center, radius, criterion='center_of_mass'):
center = np.array(center).reshape(3)
selected = []
for residue in self.residues:
distance = np.linalg.norm(residue[criterion] - center, ord=2)
print(residue[criterion], distance)
if distance < radius:
selected.append(residue)
return selected
def query_residues_ligand(self, ligand, radius, criterion='center_of_mass'):
selected = []
sel_idx = set()
# The time-complexity is O(mn).
for center in ligand['pos']:
for i, residue in enumerate(self.residues):
distance = np.linalg.norm(residue[criterion] - center, ord=2)
if distance < radius and i not in sel_idx:
selected.append(residue)
sel_idx.add(i)
return selected
def residues_to_pdb_block(self, residues, name='POCKET'):
block = "HEADER %s\n" % name
block += "COMPND %s\n" % name
for residue in residues:
for atom_idx in residue['atoms']:
block += self.atoms[atom_idx]['line'] + "\n"
block += "END\n"
return block
2.2.2 组合蛋白和小分子的字典 ProteinLigandData()
比较简单,合并蛋白和小分子的字典。
class ProteinLigandData(Data):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@staticmethod
def from_protein_ligand_dicts(protein_dict=None, ligand_dict=None, **kwargs):
# 将蛋白和小分子的字典合并成为一个新的字典,
# 蛋白的 key 前面增加 protein_
# ligand 的 key 前面增加 ligand_
instance = ProteinLigandData(**kwargs)
if protein_dict is not None:
for key, item in protein_dict.items():
instance['protein_' + key] = item
if ligand_dict is not None:
for key, item in ligand_dict.items():
instance['ligand_' + key] = item
instance['ligand_nbh_list'] = {i.item(): [j.item() for k, j in enumerate(instance.ligand_bond_index[1])
if instance.ligand_bond_index[0, k].item() == i]
for i in instance.ligand_bond_index[0]}
return instance
def __inc__(self, key, value, *args, **kwargs):
if key == 'ligand_bond_index':
return self['ligand_element'].size(0)
else:
return super().__inc__(key, value)
2.3 引导模型数据转换 GuideFeaturizeProteinAtom()
引导模型数据转化分为两块,分别对应蛋白和小分子的数据转换,对应类 GuideFeaturizeProteinAtom() 和 GuideFeaturizeLigandAtom() 。来源于:
from utils.transforms_prop import FeaturizeProteinAtom as GuideFeaturizeProteinAtom
from utils.transforms_prop import FollowerFeaturizeLigandAtom as GuideFeaturizeLigandAtom
整个引导模型的数据转换代码,如下(在 sample_for_pocket_guided 函数中):
# Guide Transforms 引导模型的数据转换器
guide_protein_featurizer = GuideFeaturizeProteinAtom()
guide_ligand_featurizer = GuideFeaturizeLigandAtom(mode=ckpt['config']['data']['transform']['ligand_atom_mode'])
2.3.1 引导模型的蛋白数据特征化
由 GuideFeaturizeProteinAtom() 类实现,主要功能为将 protein_data 字典组合成为节点特征,以及返回节点特征维度。
蛋白中原子的特征包括:元素(H, C, N, O, S, Se)、氨基酸类型、是否是主链。
代码如下:
class FeaturizeProteinAtom(object):
def __init__(self):
super().__init__()
self.atomic_numbers = torch.LongTensor([1, 6, 7, 8, 16, 34]) # H, C, N, O, S, Se
self.max_num_aa = 20
@property
def feature_dim(self):
# 节点特征维度
return self.atomic_numbers.size(0) + self.max_num_aa + 1
def __call__(self, data: ProteinLigandData):
# 将 data(字典)组合成节点特征矩阵
element = data.protein_element.view(-1, 1) == self.atomic_numbers.view(1, -1) # (N_atoms, N_elements)
amino_acid = F.one_hot(data.protein_atom_to_aa_type, num_classes=self.max_num_aa)
is_backbone = data.protein_is_backbone.view(-1, 1).long()
x = torch.cat([element, amino_acid, is_backbone], dim=-1)
data.protein_atom_feature = x
return data
2.3.2 引导模型的小分子数据特征化
将小分子数据 ligand_data 特征化,由 dict 转化为节点特征矩阵,以及提取小分子特征矩阵的维度。
小分子可以兼容的元素为:H, C, N, O, F, P, S, Cl ,小分子的节点特征为:元素、原子芳香性。(注意,小分子和蛋白的节点特征不一致)
代码如下:
class FollowerFeaturizeLigandAtom(object):
'''
将小分子的 ligand_data 字典转化为特征矩阵
'''
def __init__(self, mode='basic'):
super().__init__()
self.atomic_numbers = torch.LongTensor([1, 6, 7, 8, 9, 15, 16, 17]) # H, C, N, O, F, P, S, Cl
assert mode in ['basic', 'add_aromatic', 'full']
self.mode = mode
# self.n_degree = torch.LongTensor([0, 1, 2, 3, 4, 5]) # 0 - 5
# self.n_num_hs = 6 # 0 - 5
@property
def feature_dim(self):
if self.mode == 'basic':
return len(MAP_ATOM_TYPE_ONLY_TO_INDEX)
elif self.mode == 'add_aromatic':
return len(MAP_ATOM_TYPE_AROMATIC_TO_INDEX)
else:
return len(MAP_ATOM_TYPE_FULL_TO_INDEX)
def __call__(self, data: ProteinLigandData):
element_list = data['ligand_element'] # 元素列表
aromatic_list = data['ligand_atom_feature'][:,1] # 原子芳香性列表
x = [get_index(e, None, a, self.mode) for e, a in zip(element_list, aromatic_list)]
x = torch.tensor(x)
x = F.one_hot(x, num_classes=self.feature_dim) # 节点特正 onr-hot
data.ligand_atom_feature_full = x # 小分子节点特征矩阵
return data
2.4 加载梯度引导模型 get_guide_model()
在 sample_for_pocket_guided 函数中,加载梯度引导模型的相关代码是:(注:此部分为我们修改过后的代码)
guide_models = []
for guide_model_config in config.guide_models:
# Load Guide Checkpoint
guide_ckpt = torch.load(guide_model_config.checkpoint, map_location=device)
print(f"\nGuide Name: {guide_model_config.name}")
logger.info(f"Guide Name: {guide_model_config.name}")
print(f"\nGuide Training Config: {guide_ckpt['config']}")
logger.info(f"Guide Training Config: {guide_ckpt['config']}")
# Guide model
guide_model = get_guide_model(guide_ckpt['config'], guide_protein_featurizer.feature_dim, guide_ligand_featurizer.feature_dim)
guide_model.load_state_dict(guide_ckpt['model'])
guide_model = guide_model.to(device)
guide_model.eval()
guide_models.append(guide_model)
通过 for 循环,get_guide_model 函数逐个加载梯度引导模型到梯度引导模型列表中 guide_models。
get_guide_model() 实际上对应 get_model() 函数, 来源于:
from scripts.property_prediction.inference import get_model as get_guide_model
get_model() 函数传入模型的类型(默认是 egnn ),蛋白节点特征维度、小分子特征维度,返回梯度引导模型 DockGuideNet3D。
代码如下:
def get_model(config, protein_atom_feat_dim, ligand_atom_feat_dim):
'''
实例化 梯度引导模型 DockGuideNet3D
'''
if config.model.model_type == 'egnn':
## the noisy guide model trained on diffused inputs
model = DockGuideNet3D(
config.model,
protein_atom_feature_dim=protein_atom_feat_dim,
ligand_atom_feature_dim=ligand_atom_feat_dim
)
return model
else:
raise NotImplementedError
2.5 引导分子采样 sample_guided_diffusion_ligand()
在完成口袋、主模型、梯度引导模型的加载,然后使用 sample_guided_diffusion_ligand 函数进行批次化的引导分子采样(包含了数据的批次化)。
sample_guided_diffusion_ligand 函数源自于 ./scipts/sample_multi_guided_diffusion.py 文件,其代码如下:
def sample_guided_diffusion_ligand(model, guide_models, guide_configs, data, num_samples, batch_size=1, device='cuda:0',
num_steps=None, pos_only=False, center_pos_mode='protein',
sample_num_atoms='prior'): # batch_size=16 改成 4
'''
逐个批次进行分子生成,每个批次内:
1. 采样生成分子的原子数;
2. 初始化小分子的位置和节点类型
3. 使用主模型的 model.sample_multi_guided_diffusion,进行采样
4. 生成分子坐标、节点类型、轨迹的解批次
'''
# 将主函数和梯度引导函数设置为评估模式
model.eval()
for guide_model in guide_models:
guide_model.eval()
# 多个分子的坐标和节点类型
all_pred_pos, all_pred_v = [], []
# 多个分子的生成轨迹(即:每个时间 t 下的坐标和节点类型)
all_pred_pos_traj, all_pred_pos0_traj, all_pred_v_traj = [], [], []
all_pred_v0_traj, all_pred_vt_traj = [], []
time_list = []
# 批次数
num_batch = int(np.ceil(num_samples / batch_size))
current_i = 0
for i in tqdm(range(num_batch)): # 逐批次生成
# 批次中生成的分子数
n_data = batch_size if i < num_batch - 1 else num_samples - batch_size * (num_batch - 1)
# 复制data, 组合成批次数据,
# FOLLOW_BATCH 特征在批处理中将保持单独的批次信息,确保合并后仍能区分原来属于哪个图结构的数据。
# 这里跟踪的是('protein_element', 'ligand_element', 'ligand_bond_type')
# 会产生一个 protein_element_batch 的张量形状为 (批次中节点数,),标记每一个节点属于哪一个图(分子)。
batch = Batch.from_data_list([data.clone() for _ in range(n_data)], follow_batch=FOLLOW_BATCH).to(device)
t1 = time.time()
with torch.no_grad():
batch_protein = batch.protein_element_batch
# 各生成分子的原子数 batch_ligand
if sample_num_atoms == 'prior':
# 通过口袋形状,根据预定义好的口袋跨度-分子原子数分布中,采样各生成分子的原子数
pocket_size = atom_num.get_space_size(batch.protein_pos.detach().cpu().numpy()) # 口袋跨度
ligand_num_atoms = [atom_num.sample_atom_num(pocket_size).astype(int) for _ in range(n_data)] # 采样原子数
batch_ligand = torch.repeat_interleave(torch.arange(n_data), torch.tensor(ligand_num_atoms)).to(device) # batch 口袋原子数
elif sample_num_atoms == 'range':
# 按照范围设定
ligand_num_atoms = list(range(current_i + 1, current_i + n_data + 1))
batch_ligand = torch.repeat_interleave(torch.arange(n_data), torch.tensor(ligand_num_atoms)).to(device)
elif sample_num_atoms == 'ref':
# 按照当前分子的原子数
batch_ligand = batch.ligand_element_batch
ligand_num_atoms = scatter_sum(torch.ones_like(batch_ligand), batch_ligand, dim=0).tolist()
else:
raise ValueError
# init ligand pos,初始化小分子坐标
# 使用 scatter_mean计算 batch.protein_pos 的加权均值
# 按照 batch_protein 权重计算,舍弃小分子, 以及沿着批次维度计算),
# batch_protein 和 batch_ligand 都是 mask,标记 batch 中哪里是蛋白,哪里是小分子
center_pos = scatter_mean(batch.protein_pos, batch_protein, dim=0)
batch_center_pos = center_pos[batch_ligand]
init_ligand_pos = batch_center_pos + torch.randn_like(batch_center_pos) # 蛋白质心加上随机值
# init ligand v 初始化小分子节点类型,均为0初始化,随机概率初始化
if pos_only:
# 节点类型不初始化,均为0
init_ligand_v = batch.ligand_atom_feature_full
else:
# 节点类型随机概率初始化
uniform_logits = torch.zeros(len(batch_ligand), model.num_classes).to(device)
init_ligand_v = log_sample_categorical(uniform_logits)
# 主模型+引导模型进行分子生成
r = model.sample_multi_guided_diffusion(
guide_models=guide_models, # 引导模型
guide_configs=guide_configs, # 引导模型配置
n_data=n_data, # 批次中分子数
device=device,
protein_pos=batch.protein_pos, # 蛋白坐标
protein_v=batch.protein_atom_feature.float(), # 蛋白节点类型
batch_protein=batch_protein, # 蛋白的 mask
init_ligand_pos=init_ligand_pos, # 初始化小分子坐标
init_ligand_v=init_ligand_v, # 初始化小分子 节点类型
batch_ligand=batch_ligand, # 小分子 mask
num_steps=num_steps, # 去噪/扩散步数
pos_only=pos_only,
center_pos_mode=center_pos_mode # 中心模式, 蛋白
)
ligand_pos, ligand_v, ligand_pos_traj, ligand_v_traj = r['pos'], r['v'], r['pos_traj'], r['v_traj']
ligand_v0_traj, ligand_vt_traj, ligand_pos0_traj = r['v0_traj'], r['vt_traj'], r['pos0_traj']
# unbatch pos,坐标 ligand_pos 解批次化,并转为 numpy, 矩阵形状为 num_samples * [num_atoms_i, 3]
ligand_cum_atoms = np.cumsum([0] + ligand_num_atoms)
ligand_pos_array = ligand_pos.cpu().numpy().astype(np.float64)
all_pred_pos += [ligand_pos_array[ligand_cum_atoms[k]:ligand_cum_atoms[k + 1]] for k in
range(n_data)] # num_samples * [num_atoms_i, 3]
# 坐标去噪轨迹 ligand_pos_traj 解批次
all_step_pos = [[] for _ in range(n_data)]
for p in ligand_pos_traj: # step_i
p_array = p.cpu().numpy().astype(np.float64)
for k in range(n_data):
all_step_pos[k].append(p_array[ligand_cum_atoms[k]:ligand_cum_atoms[k + 1]])
all_step_pos = [np.stack(step_pos) for step_pos in
all_step_pos] # num_samples * [num_steps, num_atoms_i, 3]
all_pred_pos_traj += [p for p in all_step_pos]
# all_step_pos0 解批次
all_step_pos0 = [[] for _ in range(n_data)]
for p in ligand_pos0_traj: # step_i
p_array = p.cpu().numpy().astype(np.float64)
for k in range(n_data):
all_step_pos0[k].append(p_array[ligand_cum_atoms[k]:ligand_cum_atoms[k + 1]])
all_step_pos0 = [np.stack(step_pos) for step_pos in
all_step_pos0] # num_samples * [num_steps, num_atoms_i, 3]
all_pred_pos0_traj += [p for p in all_step_pos0]
# unbatch v 节点类型 ligand_v 解批次
ligand_v_array = ligand_v.cpu().numpy()
all_pred_v += [ligand_v_array[ligand_cum_atoms[k]:ligand_cum_atoms[k + 1]] for k in range(n_data)]
# 节点类型去噪轨迹 ligand_v_traj 解批次
all_step_v = unbatch_v_traj(ligand_v_traj, n_data, ligand_cum_atoms)
all_pred_v_traj += [v for v in all_step_v]
if not pos_only:
all_step_v0 = unbatch_v_traj(ligand_v0_traj, n_data, ligand_cum_atoms)
all_pred_v0_traj += [v for v in all_step_v0]
all_step_vt = unbatch_v_traj(ligand_vt_traj, n_data, ligand_cum_atoms)
all_pred_vt_traj += [v for v in all_step_vt]
t2 = time.time()
time_list.append(t2 - t1)
current_i += n_data
return all_pred_pos, all_pred_v, all_pred_pos_traj, all_pred_v_traj, all_pred_v0_traj, all_pred_vt_traj, all_pred_pos0_traj, time_list
sample_guided_diffusion_ligand 函数的主要内容为:逐个批次进行分子生成,每个批次内,
1. 采样生成分子的原子数;sample_num_atoms 参数决定生成分子的原子数来源, prior 根据预设好的口袋跨度-原子数关系采样;range 按照预设值的范围进行采样;ref 根据参考分子的原数量决定。
2. 初始化小分子的位置和节点类型:坐标初始化为蛋白中心+随机数初始化;节点类型使用随机概率初始化或者不初始化(均为0)。
3. 使用主模型的 model.sample_multi_guided_diffusion 函数,进行采样
4. 生成分子坐标、节点类型、轨迹的解批次
其中,关键的 model.sample_multi_guided_diffusion 函数是,主模型+梯度引导模型,根据初始化的小分子坐标、节点类型以及蛋白的坐标和节点类型,进行分子生成。这个 sample_multi_guided_diffusion 函数是作者在 TargetDiff 模型 (ScorePosNet3D)上新增的函数,同样需新增的函数还有 sample_guided_diffusion 。除了这两个函数,原来 TargetDiff 的代码保持不变。
2.6 TagetDiff +多引导模型采样(ScorePosNet3D 中sample_multi_guided_diffusion)
主要由 ScorePosNet3D 模型的 sample_multi_guided_diffusion 方法实现。(注:由于 ScorePosNet3D 是一个完整的 3D 分子生成扩散模型,包含了训练损失,生成,以及其他添加噪音等一系列的函数,比较复杂。在这里仅仅对多种梯度模型引导的分子生成函数 sample_multi_guided_diffusion 及其支持函数进行解析。)
在 sample_multi_guided_diffusion 函数中,首先是基本准备,包括:检查梯度引导模型数量与配置数量是否相同;坐标中心归 0,调整小分子和蛋白的坐标;初始化 V_t, V_0, X_t, X_0, V_t-1, X_t-1等去噪过程记录列表;初始化时间 t 列表。代码为:
# 检查 梯度引导模型数量与配置数量是否相同
assert len(guide_models) == len(guide_configs), f"guide_models and guide_configs must have the same length"
# 去噪步数
if num_steps is None:
num_steps = self.num_timesteps
num_graphs = batch_protein.max().item() + 1
# 整体中心归 0,调整小分子和蛋白的坐标
protein_pos, init_ligand_pos, offset = center_pos(
protein_pos, init_ligand_pos, batch_protein, batch_ligand, mode=center_pos_mode)
pos_traj, v_traj = [], []
v0_pred_traj, vt_pred_traj, pos0_traj = [], [], []
ligand_pos, ligand_v = init_ligand_pos, init_ligand_v
# time sequence 时间序列,从大到小
time_seq = list(reversed(range(self.num_timesteps - num_steps, self.num_timesteps)))
然后进行逐步去噪过程。主要涉及:
(1)某个时刻 t 下,输入 V_t, X_t 到神经网络预测 V_0,X_0, 然后计算先验 V_t-1, X_t-1,提取真实的分子坐标即 X_0。 涉及代码:
# 逐步去噪
for i in tqdm(time_seq, desc='sampling', total=len(time_seq)):
# 某 t 时刻
t = torch.full(size=(num_graphs,), fill_value=i, dtype=torch.long, device=protein_pos.device)
# 输入 V_t,神经网络预测 V_0
with torch.no_grad():
preds = self(
protein_pos=protein_pos,
protein_v=protein_v,
batch_protein=batch_protein,
init_ligand_pos=ligand_pos,
init_ligand_v=ligand_v,
batch_ligand=batch_ligand,
time_step=t
)
# Compute posterior mean and variance
if self.model_mean_type == 'noise':
# 如果神经网路输出的是噪音,提取真实分子坐标
pred_pos_noise = preds['pred_ligand_pos'] - ligand_pos
pos0_from_e = self._predict_x0_from_eps(xt=ligand_pos, eps=pred_pos_noise, t=t, batch=batch_ligand)
v0_from_e = preds['pred_ligand_v']
raise NotImplementedError
elif self.model_mean_type == 'C0':
# 如果神经网络输出的是真实分子坐标
pos0_from_e = preds['pred_ligand_pos']
v0_from_e = preds['pred_ligand_v']
else:
raise ValueError
# 输入 V_t, X_t 和神经网络预测的 V_0, X_0 计算先验 V_t-1, X_t-1
# 坐标
pos_model_mean = self.q_pos_posterior(x0=pos0_from_e, xt=ligand_pos, t=t, batch=batch_ligand)
# 坐标的对数方差,即不确定性
pos_log_variance = extract(self.posterior_logvar, t, batch_ligand)
其中,关于坐标归 0 的函数 center_pos,逐个计算批次中每个蛋白-小分子体系的坐标均值,代码如下:
def center_pos(protein_pos, ligand_pos, batch_protein, batch_ligand, mode='protein'):
if mode == 'none':
# 不进行任何中心化操作,即坐标不归0
offset = 0.
pass
elif mode == 'protein':
# 计算 protein_pos 在每个批次的均值 offset。
# scatter_mean 函数根据 batch_protein 张量逐个计算均值。
offset = scatter_mean(protein_pos, batch_protein, dim=0) # 计算每个批次的蛋白中心位置
protein_pos = protein_pos - offset[batch_protein]
ligand_pos = ligand_pos - offset[batch_ligand]
else:
raise NotImplementedError
return protein_pos, ligand_pos, offset
(2) 输入 X_t-1, V_t-1 由梯度引导模型计算梯度。
注意,每一个引导模型都有自己的权重,以及梯度的缩放因子,即每个梯度引导模型产生的梯度都会受到 权重*梯度缩放因子的影响。对于多梯度引导模型来说,梯度权重,梯度缩放因子都是超参数,根据不同的模型需要不同设置。(显然,这是一个超参数问题,实际应用需要超参数检索???)
各个梯度引导模型计算梯度的函数是 get_gradients_guide()。将在梯度引导模型部分详细解析。
涉及代码:
## Classifier Guidance for the Diffusion 梯度引导
ligand_pos_grad, ligand_v_grad = None, None
# 逐个计算不同引导模型的梯度
for guide_model, guide_config in zip(guide_models, guide_configs):
guide_weight = guide_config.weight
gradient_scale_cord = guide_config.gradient_scale_cord # 坐标梯度引导强度
gradient_scale_categ = guide_config.gradient_scale_categ # 节点类型梯度引导强度
clamp_pred_min = guide_config.get("clamp_pred_min", None) # 最小梯度值,默认 None
clamp_pred_max = guide_config.get("clamp_pred_max", None) # 最大梯度值,默认为 None
kind = guide_config.guide_kind # 梯度引导类型
kind = torch.tensor([KMAP[kind]]*n_data).to(device) # 引导类型编号 {'Ki': 1, 'Kd': 2, 'IC50': 3}
# 调用梯度引导模型计算梯度
curr_ligand_pos_grad, curr_ligand_v_grad = guide_model.get_gradients_guide(
protein_pos=protein_pos,
protein_atom_feature=protein_v,
ligand_pos=ligand_pos,
ligand_atom_feature=F.one_hot(ligand_v,self.num_classes).float(), # 小分子的节点类型从概率转换为 noe-hot
batch_protein=batch_protein,
batch_ligand=batch_ligand,
# output_kind=kind,
time_step=t,
pos_only=False,
clamp_pred_min=clamp_pred_min,
clamp_pred_max=clamp_pred_max,
)
# NOTE: extra terms to be applied later
# 坐标类型的梯度
if ligand_pos_grad is None:
# 第一个梯度引导模型,未有梯度积累 ligand_pos_grad 为 None
ligand_pos_grad = guide_weight * gradient_scale_cord * curr_ligand_pos_grad
else:
ligand_pos_grad += (guide_weight * gradient_scale_cord * curr_ligand_pos_grad)
# 节点类型的梯度
if gradient_scale_categ != 0:
if ligand_v_grad is None:
ligand_v_grad = guide_weight * gradient_scale_categ * curr_ligand_v_grad
else:
ligand_v_grad += (guide_weight * gradient_scale_categ * curr_ligand_v_grad)
(3) 使用梯度更新坐标 X_t-1和节点类型 V_t-1。
需要注意的是,每一个坐标的梯度(移动程度)是根据坐标的确定性概率 (pos_log_variance) 进行了调整。节点特征部分,文章和代码中默认是不能使用梯度更新的,但是保留相应的代码。节点类型的梯度更新的是概率,而不是神经网络输入的对数概率。(神经网络输入和输出的节点类型都是对数概率)
涉及代码:
## update the coordinates based on the computed gradient after scaling
# 更新节点类型,按照不确定进行更新
ligand_pos_grad_update = ligand_pos_grad * ((0.5 * pos_log_variance).exp())
# 更新坐标位置
pos_model_mean = pos_model_mean - ligand_pos_grad_update
# (注:按照超参数设置,节点类型是不更新的,因此文章设置了报错)
# 节点类型的梯度只做了展示,但是没有更新到节点中
assert ligand_v_grad is None, "Non-zero value for `gradient_scale_categ` is experimental and not part of the paper."
## end of classifier guidance
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float())[batch_ligand].unsqueeze(-1) # ligand 的节点 mask, 如果 t=0则全部为0,即不更新
# 随机添加噪音到坐标中
ligand_pos_next = pos_model_mean + nonzero_mask * (0.5 * pos_log_variance).exp() * torch.randn_like(
ligand_pos)
# 将 X_t-1 设置为新的 X_t 准备下一个 t-1 时刻的去噪
ligand_pos = ligand_pos_next
if not pos_only:
# 如果扩散模型不是只对坐标,节点类型也参与扩散,默认是 pos_only 为 False
## v0^ predicted from current time step, 神经网络预测的 V_0
log_ligand_v_recon = F.log_softmax(v0_from_e, dim=-1)
## vt,V_t
log_ligand_v = index_to_log_onehot(ligand_v, self.num_classes)
## posterior probability of vt-1 given v0^ and vt
# 计算 V_t-1, 是一个 log 概率
log_model_prob = self.q_v_posterior(log_ligand_v_recon, log_ligand_v, t, batch_ligand)
## Classifer update
## update the categories based on gradient guidance
# 如果节点类型由梯度,参数默认是没有梯度
if ligand_v_grad is not None:
## heuristic-driven. no significant mathematical justification
# prob [0.8, 0.2]
# guidance [-0.9, -0.7]
# [-0.1, -0.5]
# [0.4, 0.0]
# [1, 0]
# applying gradient scale & weights have been done after the get_gradients_guide() step itself
# 梯度是添加到概率中的,不是对数概率
updated_model_prob = torch.exp(log_model_prob) - ligand_v_grad
# 每一行的最小值归零,使得 updated_model_prob 中的所有值都非负,并保持每一行的相对差异不变
updated_model_prob = updated_model_prob - torch.min(updated_model_prob,axis=-1).values.unsqueeze(1)
# 每一行概率归一化
updated_model_prob = (updated_model_prob.T / updated_model_prob.sum(axis=1)).T
# 更新后的概率转化为 对数概率
log_model_prob = torch.log(updated_model_prob)
## end of classifier guidance
## sample vt-1 from the probabilities 根据节点对数概率采样节点概率(节点特征)
ligand_v_next = log_sample_categorical(log_model_prob)
# import pdb;pdb.set_trace()
v0_pred_traj.append(log_ligand_v_recon.clone().cpu())
vt_pred_traj.append(log_model_prob.clone().cpu())
# 将 V_t-1 设置为新的 V_t 准备下一个 t-1 时刻的去噪
ligand_v = ligand_v_next
(4) 记录过程
当然,要恢复小分子的中心
# 记录过程
ori_ligand_pos0 = pos0_from_e + offset[batch_ligand] # 还原小分子位置,把坐标中心加回来
ori_ligand_pos = ligand_pos + offset[batch_ligand]
pos0_traj.append(ori_ligand_pos0.clone().cpu())
pos_traj.append(ori_ligand_pos.clone().cpu())
v_traj.append(ligand_v.clone().cpu())
(5) 函数返回
经过全部 t 时刻的去噪,既可返回生成的小分子的坐标和节点类型。
ligand_pos = ligand_pos + offset[batch_ligand] # 还原小分子位置,把坐标中心加回来
return {
'pos': ligand_pos, # 生成的坐标
'v': ligand_v, # 生成的节点类型
'pos_traj': pos_traj, # 每个 t 时刻,X_t-1, 包含了梯度调整后的结果
'pos0_traj': pos0_traj, # 每个 t 时刻,神经网络预测的 X_0
'v_traj': v_traj, # 每个 t 时刻,V_t-1, 包含了梯度调整后的结果
'v0_traj': v0_pred_traj, # 每个 t 时刻,神经网络预测的 V_0
'vt_traj': vt_pred_traj # 每个 t 时刻,V_t
}
(6) ScorePosNet3D 模型的 sample_multi_guided_diffusion 方法的完整代码:
def sample_multi_guided_diffusion(self, guide_models, guide_configs, n_data, device, protein_pos, protein_v, batch_protein,
init_ligand_pos, init_ligand_v, batch_ligand,
num_steps=None, center_pos_mode=None, pos_only=False):
# 检查 梯度引导模型数量与配置数量是否相同
assert len(guide_models) == len(guide_configs), f"guide_models and guide_configs must have the same length"
# 去噪步数
if num_steps is None:
num_steps = self.num_timesteps
num_graphs = batch_protein.max().item() + 1
# 中心归 0,调整小分子和蛋白的坐标
protein_pos, init_ligand_pos, offset = center_pos(
protein_pos, init_ligand_pos, batch_protein, batch_ligand, mode=center_pos_mode)
pos_traj, v_traj = [], []
v0_pred_traj, vt_pred_traj, pos0_traj = [], [], []
ligand_pos, ligand_v = init_ligand_pos, init_ligand_v
# time sequence 时间序列,从大到小
time_seq = list(reversed(range(self.num_timesteps - num_steps, self.num_timesteps)))
# 逐步去噪
for i in tqdm(time_seq, desc='sampling', total=len(time_seq)):
# 某 t 时刻
t = torch.full(size=(num_graphs,), fill_value=i, dtype=torch.long, device=protein_pos.device)
# 输入 V_t,神经网络预测 V_0
with torch.no_grad():
preds = self(
protein_pos=protein_pos,
protein_v=protein_v,
batch_protein=batch_protein,
init_ligand_pos=ligand_pos,
init_ligand_v=ligand_v,
batch_ligand=batch_ligand,
time_step=t
)
# Compute posterior mean and variance
if self.model_mean_type == 'noise':
# 如果神经网路输出的是噪音,提取真实分子坐标
pred_pos_noise = preds['pred_ligand_pos'] - ligand_pos
pos0_from_e = self._predict_x0_from_eps(xt=ligand_pos, eps=pred_pos_noise, t=t, batch=batch_ligand)
v0_from_e = preds['pred_ligand_v']
raise NotImplementedError
elif self.model_mean_type == 'C0':
# 如果神经网络输出的是真实分子坐标
pos0_from_e = preds['pred_ligand_pos']
v0_from_e = preds['pred_ligand_v']
else:
raise ValueError
# 输入 V_t, X_t 和神经网络预测的 V_0, X_0 计算先验 V_t-1, X_t-1
# 坐标
pos_model_mean = self.q_pos_posterior(x0=pos0_from_e, xt=ligand_pos, t=t, batch=batch_ligand)
# 坐标的对数方差,即不确定性
pos_log_variance = extract(self.posterior_logvar, t, batch_ligand)
## Classifier Guidance for the Diffusion 梯度引导
ligand_pos_grad, ligand_v_grad = None, None
# 逐个计算不同引导模型的梯度
for guide_model, guide_config in zip(guide_models, guide_configs):
guide_weight = guide_config.weight
gradient_scale_cord = guide_config.gradient_scale_cord # 坐标梯度引导强度
gradient_scale_categ = guide_config.gradient_scale_categ # 节点类型梯度引导强度
clamp_pred_min = guide_config.get("clamp_pred_min", None) # 最小梯度值,默认 None
clamp_pred_max = guide_config.get("clamp_pred_max", None) # 最大梯度值,默认为 None
kind = guide_config.guide_kind # 梯度引导类型
kind = torch.tensor([KMAP[kind]]*n_data).to(device) # 引导类型编号 {'Ki': 1, 'Kd': 2, 'IC50': 3}
# ligand_pos_grad = guide_model.get_gradients_guide(
# protein_pos=protein_pos,
# protein_atom_feature=protein_v,
# ligand_pos=ligand_pos,
# ligand_atom_feature=F.one_hot(ligand_v,self.num_classes).float(),
# batch_protein=batch_protein,
# batch_ligand=batch_ligand,
# output_kind=kind,
# pos_only=True
# )
# 调用梯度引导模型计算梯度
curr_ligand_pos_grad, curr_ligand_v_grad = guide_model.get_gradients_guide(
protein_pos=protein_pos,
protein_atom_feature=protein_v,
ligand_pos=ligand_pos,
ligand_atom_feature=F.one_hot(ligand_v,self.num_classes).float(), # 小分子的节点类型从概率转换为 noe-hot
batch_protein=batch_protein,
batch_ligand=batch_ligand,
# output_kind=kind,
time_step=t,
pos_only=False,
clamp_pred_min=clamp_pred_min,
clamp_pred_max=clamp_pred_max,
)
# NOTE: extra terms to be applied later
# 坐标类型的梯度
if ligand_pos_grad is None:
# 第一个梯度引导模型,未有梯度积累 ligand_pos_grad 为 None
ligand_pos_grad = guide_weight * gradient_scale_cord * curr_ligand_pos_grad
else:
ligand_pos_grad += (guide_weight * gradient_scale_cord * curr_ligand_pos_grad)
# 节点类型的梯度
if gradient_scale_categ != 0:
if ligand_v_grad is None:
ligand_v_grad = guide_weight * gradient_scale_categ * curr_ligand_v_grad
else:
ligand_v_grad += (guide_weight * gradient_scale_categ * curr_ligand_v_grad)
## update the coordinates based on the computed gradient after scaling
# 更新节点类型,按照不确定进行更新
ligand_pos_grad_update = ligand_pos_grad * ((0.5 * pos_log_variance).exp())
# 更新坐标位置
pos_model_mean = pos_model_mean - ligand_pos_grad_update
# (注:按照超参数设置,节点类型是不更新的,因此文章设置了报错)
# 节点类型的梯度只做了展示,但是没有更新到节点中
assert ligand_v_grad is None, "Non-zero value for `gradient_scale_categ` is experimental and not part of the paper."
## end of classifier guidance
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float())[batch_ligand].unsqueeze(-1) # ligand 的节点 mask, 如果 t=0则全部为0,即不更新
# 随机添加噪音到坐标中
ligand_pos_next = pos_model_mean + nonzero_mask * (0.5 * pos_log_variance).exp() * torch.randn_like(
ligand_pos)
# 将 X_t-1 设置为新的 X_t 准备下一个 t-1 时刻的去噪
ligand_pos = ligand_pos_next
if not pos_only:
# 如果扩散模型不是只对坐标,节点类型也参与扩散,默认是 pos_only 为 False
## v0^ predicted from current time step, 神经网络预测的 V_0
log_ligand_v_recon = F.log_softmax(v0_from_e, dim=-1)
## vt,V_t
log_ligand_v = index_to_log_onehot(ligand_v, self.num_classes)
## posterior probability of vt-1 given v0^ and vt
# 计算 V_t-1, 是一个 log 概率
log_model_prob = self.q_v_posterior(log_ligand_v_recon, log_ligand_v, t, batch_ligand)
## Classifer update
## update the categories based on gradient guidance
# 如果节点类型由梯度,参数默认是没有梯度
if ligand_v_grad is not None:
## heuristic-driven. no significant mathematical justification
# prob [0.8, 0.2]
# guidance [-0.9, -0.7]
# [-0.1, -0.5]
# [0.4, 0.0]
# [1, 0]
# applying gradient scale & weights have been done after the get_gradients_guide() step itself
# 梯度是添加到概率中的,不是对数概率
updated_model_prob = torch.exp(log_model_prob) - ligand_v_grad
# 每一行的最小值归零,使得 updated_model_prob 中的所有值都非负,并保持每一行的相对差异不变
updated_model_prob = updated_model_prob - torch.min(updated_model_prob,axis=-1).values.unsqueeze(1)
# 每一行概率归一化
updated_model_prob = (updated_model_prob.T / updated_model_prob.sum(axis=1)).T
# 更新后的概率转化为 对数概率
log_model_prob = torch.log(updated_model_prob)
## end of classifier guidance
## sample vt-1 from the probabilities 根据节点对数概率采样节点概率(节点特征)
ligand_v_next = log_sample_categorical(log_model_prob)
# import pdb;pdb.set_trace()
v0_pred_traj.append(log_ligand_v_recon.clone().cpu())
vt_pred_traj.append(log_model_prob.clone().cpu())
# 将 V_t-1 设置为新的 V_t 准备下一个 t-1 时刻的去噪
ligand_v = ligand_v_next
# 记录过程
ori_ligand_pos0 = pos0_from_e + offset[batch_ligand] # 还原小分子位置,把坐标中心加回来
ori_ligand_pos = ligand_pos + offset[batch_ligand]
pos0_traj.append(ori_ligand_pos0.clone().cpu())
pos_traj.append(ori_ligand_pos.clone().cpu())
v_traj.append(ligand_v.clone().cpu())
ligand_pos = ligand_pos + offset[batch_ligand] # 还原小分子位置,把坐标中心加回来
return {
'pos': ligand_pos, # 生成的坐标
'v': ligand_v, # 生成的节点类型
'pos_traj': pos_traj, # 每个 t 时刻,X_t-1, 包含了梯度调整后的结果
'pos0_traj': pos0_traj, # 每个 t 时刻,神经网络预测的 X_0
'v_traj': v_traj, # 每个 t 时刻,V_t-1, 包含了梯度调整后的结果
'v0_traj': v0_pred_traj, # 每个 t 时刻,神经网络预测的 V_0
'vt_traj': vt_pred_traj # 每个 t 时刻,V_t
}
三、训练梯度引导模型
虽然在分子生成时,可以同时使用多个梯度引导模型对 TargetDiff 的分子生成过程进行引导,但是在训练梯度引导模型时,是一个个单独训练的。如 2.1 部分描述的,每一个梯度引导模型,实际上都是一个 DockGuideNet3D() 模型,是基于不同的拟合值训练出来的。
训练梯度引导模型的命令是 (以 Binding Affinity 为例):
python scripts/train_dock_guide.py \
configs/training_dock_guide.yml
关于训练梯度引导模型的配置文件,如下:
data:
name: pl_dock_guide
path: ./data/guide/crossdocked_v1.1_rmsd1.0_pocket10
split: ./data/guide/crossdocked_pocket10_pose_split_dock_guide.pt
index_path: ./data/guide/crossdocked_v1.1_rmsd1.0_pocket10/index.pkl
transform:
ligand_atom_mode: add_aromatic
random_rot: False
model:
model_mean_type: C0 # ['noise', 'C0']
beta_schedule: sigmoid
beta_start: 1.e-7
beta_end: 2.e-3
v_beta_schedule: cosine
v_beta_s: 0.01
num_diffusion_timesteps: 1000
loss_v_weight: 100.
sample_time_method: symmetric # ['importance', 'symmetric']
time_emb_dim: 4
time_emb_mode: sin
center_pos_mode: protein
node_indicator: True
model_type: egnn
num_blocks: 1
num_layers: 9
hidden_dim: 128
n_heads: 16
edge_feat_dim: 4 # edge type feat
num_r_gaussian: 20
knn: 32 # !
num_node_types: 8
act_fn: relu
norm: True
cutoff_mode: knn # [radius, none]
ew_net_type: global # [r, m, none]
num_x2h: 1
num_h2x: 1
r_max: 10.
x2h_out_fc: False
sync_twoup: False
update_x: False
train:
seed: 2021
batch_size: 16
num_workers: 4
n_acc_batch: 1
max_iters: 200000
val_freq: 2000
pos_noise_std: 0.1
max_grad_norm: 8.0
bond_loss_weight: 1.0
optimizer:
type: adam
lr: 0.001 #5.e-4
weight_decay: 0
beta1: 0.95
beta2: 0.999
scheduler:
type: plateau
factor: 0.6
patience: 10
min_lr: 1.e-6
其中,data 部分中的 name 为数据集的名字,会据此选择不同的数据加载模类; path 则为原始蛋白-小分子复合体系的保存位置,整个路径下的每一个文件夹是一个体系,默认是 ./data/guide/crossdocked_v1.1_rmsd1.0_pocket10。index_path 则是用于生成数据集的体系列表,是一个pkl文件,一般为 index.pkl。
在配置文件中,model 部分是模型的配置,其中包括了 噪音生成器的配置(如:model_mean_type,beta_schedule, num_diffusion_timesteps 等,注意,噪音调度器的配置要与主模型 TargetDiff 一致。 ),也包括了神经网络的配置(如: model_type, num_blocks 等)。train 部分则是一些训练梯度引导模型的超参数,例如:batch_size, seed, optimizer, scheduler 等。
3.1 __main__ 函数
整个__main__ 函数比较简单。这里就简单介绍一下。
(1) 加载参数,加载配置文件,创建日志路径,复制模型配置文件,复制代码。如下代码:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('config', type=str)
parser.add_argument('--device', type=str, default='cuda')
parser.add_argument('--logdir', type=str, default='./logs') # 日志路径
parser.add_argument('--tag', type=str, default='')
parser.add_argument('--train_report_iter', type=int, default=200) # 输出间隔
args = parser.parse_args()
# Load configs 加载配置文件
config = misc.load_config(args.config)
config_name = os.path.basename(args.config)[:os.path.basename(args.config).rfind('.')]
misc.seed_all(config.train.seed)
# Logging,加载配置文件路径设置
log_dir = misc.get_new_log_dir(args.logdir, prefix=config_name, tag=args.tag)
ckpt_dir = os.path.join(log_dir, 'checkpoints') # checkpoint 保存路径
os.makedirs(ckpt_dir, exist_ok=True)
vis_dir = os.path.join(log_dir, 'vis')
os.makedirs(vis_dir, exist_ok=True)
logger = misc.get_logger('train', log_dir)
writer = torch.utils.tensorboard.SummaryWriter(log_dir)
logger.info(args)
logger.info(config)
shutil.copyfile(args.config, os.path.join(log_dir, os.path.basename(args.config))) # 复制配置文件
shutil.copytree('./models', os.path.join(log_dir, 'models')) # 复制梯度引导模型代码
(2) 初始化数据转换器
trans.FeaturizeProteinAtom() 和trans.FeaturizeLigandAtom() 可以别分别将蛋白和小分子的字典数据,转化为tensor矩阵),trans.RandomRotation() 设置坐标旋转。将初始化后的 trans.RandomRotation() ,trans.FeaturizeProteinAtom() 和trans.FeaturizeLigandAtom() 组成的 transform 输入到 get_dataset() 函数中,加载数据,生成字典 dataset 和 subsets (训练集和验证集)。
代码如下:
# Transforms 输入数据转换器
protein_featurizer = trans.FeaturizeProteinAtom()
ligand_featurizer = trans.FeaturizeLigandAtom(config.data.transform.ligand_atom_mode)
transform_list = [
protein_featurizer,
ligand_featurizer,
trans.FeaturizeLigandBond(),
]
# 坐标随机旋转
if config.data.transform.random_rot:
transform_list.append(trans.RandomRotation())
transform = Compose(transform_list)
# Datasets and loaders 加载数据,对数据进行转化
logger.info('Loading dataset...')
# 从口袋 pdb 文件。小分子 sdf 文件到数据字典。
dataset, subsets = get_dataset(
config=config.data,
transform=transform,
# heavy_only=config.data.heavy_only
index_path=config.data.index_path
)
train_set, val_set = subsets['train'], subsets['test']
logger.info(f'Training: {len(train_set)} Validation: {len(val_set)}')
get_dataset() 函数来自于 datasets/__init__.py 文件。因为这里训练的是梯度引导模型,因此,调用的是 PocketLigandPairDockGuideDataset() 类实现数据加载。
代码如下:
def get_dataset(config, *args, **kwargs):
name = config.name
root = config.path
if name == 'pl':
dataset = PocketLigandPairDataset(root, *args, **kwargs)
elif name == 'pdbbind':
dataset = PDBBindDataset(root, *args, **kwargs)
elif name == 'pl_dock_guide':
# 梯度引导模型加载数据
dataset = PocketLigandPairDockGuideDataset(root, *args, **kwargs)
else:
raise NotImplementedError('Unknown dataset: %s' % name)
if 'split' in config:
# 根据 ./data/guide/crossdocked_pocket10_pose_split_dock_guide.pt 划分训练集和验证集
split = torch.load(config.split)
subsets = {k: Subset(dataset, indices=v) for k, v in split.items()}
return dataset, subsets
else:
return dataset
关于 PocketLigandPairDockGuideDatase() 类的详细解析。与很多基于 EGNN 或者 SE3 的神经网络相同,PocketLigandPairDockGuideDatase() 类非常相似。唯一的区别是在_process函数中将 vina score 也作为一个 data 的值,后期将用于训练梯度引导模型的标签(可以替换为任何值作为标签)。
PocketLigandPairDockGuideDatase() 类的代码如下:
class PocketLigandPairDockGuideDataset(Dataset):
def __init__(self, raw_path, transform=None, version='final', index_path=None, processed_path=None):
'''
raw_path:原始数据的路径。
transform:可选的转换函数,用于在数据加载时对数据进行处理。
version:数据处理的版本,影响处理后的文件名。
index_path:指向索引文件的路径,如果为 None,则从 raw_path 中自动寻找 index.pkl 文件。
processed_path:处理后的数据存储路径,如果为 None,则根据 raw_path 和 version 自动生成。
self.db:用于存储 LMDB 数据库的连接。
self.keys:存储数据库中的键(用于索引数据)。
如果处理后的数据文件不存在,调用 _process() 方法处理原始数据并生成处理后的数据文件
'''
super().__init__()
self.raw_path = raw_path.rstrip('/')
# 数据集体系名字列表文件
self.index_path = os.path.join(self.raw_path, 'index.pkl') if index_path is None else index_path
# 如果数据路径中没有 lmdb 文件,则生成 lmdb 文件名,准备重新生成
if processed_path is None:
self.processed_path = os.path.join(os.path.dirname(self.raw_path),
os.path.basename(self.raw_path) + f'_processed_dock_guide_{version}.lmdb')
# 打印输出数据集的文件名
print(f'processed_path: {self.processed_path}')
else:
self.processed_path = processed_path
self.transform = transform # 从字典到矩阵的转换器
self.db = None
self.keys = None
if not os.path.exists(self.processed_path):
# 如果处理后的数据文件 lmdb 文件不存在,则 self._process() 重新生成
print(f'{self.processed_path} does not exist, begin processing data')
self._process()
def _connect_db(self):
"""
用于建立只读的 LMDB 数据库连接
最大数据容量为 10 GB
Establish read-only database connection
"""
assert self.db is None, 'A connection has already been opened.'
self.db = lmdb.open(
self.processed_path,
map_size=10*(1024*1024*1024), # 10GB 最大数据容量
create=False,
subdir=False,
readonly=True,
lock=False,
readahead=False,
meminit=False,
)
with self.db.begin() as txn:
self.keys = list(txn.cursor().iternext(values=False))
def _close_db(self):
self.db.close()
self.db = None
self.keys = None
def _process(self):
'''
用于处理原始数据并将其存储到 LMDB 数据库中
'''
db = lmdb.open(
self.processed_path,
map_size=10*(1024*1024*1024), # 10GB
create=True,
subdir=False,
readonly=False, # Writable
)
# 加载构建数据集的列表
with open(self.index_path, 'rb') as f:
index = pickle.load(f)
num_skipped = 0 # 跳过的体系总数
with db.begin(write=True, buffers=True) as txn:
# 逐个体系生成一个 data 字典,传到 lmdb 中
for i, (pocket_fn, ligand_fn, _, _, vina, props) in enumerate(tqdm(index)):
if pocket_fn is None: continue
try:
# data_prefix = '/data/work/jiaqi/binding_affinity'
data_prefix = self.raw_path
pocket_dict = PDBProtein(os.path.join(data_prefix, pocket_fn)).to_dict_atom() # 提取蛋白 pdb 文件到字典
ligand_dict = parse_sdf_file(os.path.join(data_prefix, ligand_fn)) # 提取小分子文件到字典
# 合并蛋白和小分子的字典到 data 字典
data = ProteinLigandData.from_protein_ligand_dicts(
protein_dict=torchify_dict(pocket_dict),
ligand_dict=torchify_dict(ligand_dict),
)
data.protein_filename = pocket_fn
data.ligand_filename = ligand_fn
data.vina_dock = vina # vina 打分,标签
data = data.to_dict() # avoid torch_geometric version issue
data.update(props)
txn.put(
key=str(i).encode(),
value=pickle.dumps(data)
)
except:
num_skipped += 1
print('Skipping (%d) %s' % (num_skipped, ligand_fn, ))
continue
db.close()
def __len__(self):
if self.db is None:
self._connect_db()
return len(self.keys)
def __getitem__(self, idx):
data = self.get_ori_data(idx)
if self.transform is not None:
data = self.transform(data)
return data
def get_ori_data(self, idx):
if self.db is None:
self._connect_db()
key = self.keys[idx]
data = pickle.loads(self.db.begin().get(key))
data = ProteinLigandData(**data)
data.id = idx
assert data.protein_pos.size(0) > 0
return data
(5) 将训练集和测试集,加载到 dataloader 中(批次化)
代码简单,如下:
需要关注一下 FOLLOW_BATCH 参数, 此参数标记字典中,哪些数据 (key) 还需要记录批次中属于哪个图的 key,默认是 FOLLOW_BATCH = ('protein_element', 'ligand_element', 'ligand_bond_type',) 。
# follow_batch = ['protein_element', 'ligand_element']
collate_exclude_keys = ['ligand_nbh_list']
# 训练集迭代器 (批次化)
train_iterator = utils_train.inf_iterator(DataLoader(
train_set,
batch_size=config.train.batch_size,
shuffle=True,
num_workers=config.train.num_workers,
follow_batch=FOLLOW_BATCH, # 字典中,还需要记录批次中属于哪个图的 key
exclude_keys=collate_exclude_keys # 不需要的 key
))
# 验证集 dataloader(将矩阵合在一起批次化)
val_loader = DataLoader(val_set, config.train.batch_size, shuffle=False,
follow_batch=FOLLOW_BATCH, exclude_keys=collate_exclude_keys)
(6) 加载梯度引导模型 DockGuideNet3D,训练优化器,梯度调节器。代码如下:
# Model 梯度引导模型
logger.info('Building model...')
model = DockGuideNet3D(
config.model,
protein_atom_feature_dim=protein_featurizer.feature_dim,
ligand_atom_feature_dim=ligand_featurizer.feature_dim
).to(args.device)
# print(model)
print(f'protein feature dim: {protein_featurizer.feature_dim} ligand feature dim: {ligand_featurizer.feature_dim}')
logger.info(f'# trainable parameters: {misc.count_parameters(model) / 1e6:.4f} M')
# Optimizer and scheduler 优化器和调度器
optimizer = utils_train.get_optimizer(config.train.optimizer, model)
scheduler = utils_train.get_scheduler(config.train.scheduler, optimizer)
(7) 批次训练函数,包括:调用梯度引导模型计算损失,梯度裁剪,记录log。代码如下:
# 批次的训练,n_acc_batch 个批次更新一次权重
def train(it):
model.train()
optimizer.zero_grad()
avg_loss = 0
for _ in range(config.train.n_acc_batch):
batch = next(train_iterator).to(args.device)
protein_noise = torch.randn_like(batch.protein_pos) * config.train.pos_noise_std
gt_protein_pos = batch.protein_pos + protein_noise
# 计算损失
loss = model.get_loss(
protein_pos=gt_protein_pos,
protein_v=batch.protein_atom_feature.float(),
batch_protein=batch.protein_element_batch,
ligand_pos=batch.ligand_pos,
ligand_v=batch.ligand_atom_feature_full,
batch_ligand=batch.ligand_element_batch,
dock=batch[config.train.get("target", "vina_dock")] # TODO: show warning if target not in config
)
loss = loss / config.train.n_acc_batch
avg_loss += loss
loss.backward()
# 梯度剪裁
orig_grad_norm = clip_grad_norm_(model.parameters(), config.train.max_grad_norm)
# 更新梯度
optimizer.step()
# log 批次间隔
if it % args.train_report_iter == 0:
logger.info(
'[Train] Iter %d | Loss %.6f | Lr: %.6f | Grad Norm: %.6f' % (
it, avg_loss, optimizer.param_groups[0]['lr'], orig_grad_norm
)
)
writer.add_scalar(f'train/loss', loss, it)
writer.add_scalar('train/lr', optimizer.param_groups[0]['lr'], it)
writer.add_scalar('train/grad', orig_grad_norm, it)
writer.flush()
(7) 验证函数,与训练函数很相似,略过,代码如下:
# 验证函数
def validate(it):
# fix time steps
sum_loss, sum_n = 0, 0
all_preds = []
all_gts = []
with torch.no_grad():
model.eval()
for batch in tqdm(val_loader, desc='Validate'):
batch = batch.to(args.device)
batch_size = batch.num_graphs
for t in np.linspace(0, model.num_timesteps - 1, 10).astype(int):
time_step = torch.tensor([t] * batch_size).to(args.device)
loss = model.get_loss(
protein_pos=batch.protein_pos,
protein_v=batch.protein_atom_feature.float(),
batch_protein=batch.protein_element_batch,
ligand_pos=batch.ligand_pos,
ligand_v=batch.ligand_atom_feature_full,
batch_ligand=batch.ligand_element_batch,
time_step=time_step,
dock=batch[config.train.get("target", "vina_dock")] # TODO: show warning if target not in config
)
sum_loss += float(loss) * batch_size
sum_n += batch_size
avg_loss = sum_loss / sum_n
# 根据平均损失更新调度器
if config.train.scheduler.type == 'plateau':
scheduler.step(avg_loss)
elif config.train.scheduler.type == 'warmup_plateau':
scheduler.step_ReduceLROnPlateau(avg_loss)
else:
scheduler.step()
# 记录
logger.info(
'[Validate] Iter %05d | Loss %.6f' % (
it, avg_loss
)
)
writer.add_scalar('val/loss', avg_loss, it)
writer.flush()
return avg_loss
(8) 在 try 语句中,尝试进行模型的训练、验证并每个间隔保存新的模型checkpoint。代码如下:
try:
best_loss, best_iter = None, None
# 逐次迭代
for it in range(1, config.train.max_iters + 1):
# with torch.autograd.detect_anomaly():
train(it) # 训练
# 验证
if it % config.train.val_freq == 0 or it == config.train.max_iters:
val_loss = validate(it)
if best_loss is None or val_loss < best_loss:
logger.info(f'[Validate] Best val loss achieved: {val_loss:.6f}')
best_loss, best_iter = val_loss, it
ckpt_path = os.path.join(ckpt_dir, '%d.pt' % it)
for old_ckpt in glob.glob(f"{ckpt_dir}/*.pt"):
os.remove(old_ckpt)
# 保存模型 checkpoint
torch.save({
'config': config,
'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'scheduler': scheduler.state_dict(),
'iteration': it,
}, ckpt_path)
else:
logger.info(f'[Validate] Val loss is not improved. '
f'Best val loss: {best_loss:.6f} at iter {best_iter}')
except KeyboardInterrupt:
logger.info('Terminating...')
至此,整个训练梯度引导模型的部分就已经结束了。其中可以看出,剩下的关键就在于梯度引导模型 DockGuideNet3D 的定义了。由于 DockGuideNet3D 会使用主模型中的噪音生成器等,因此先介绍 主模型 TargetDiff ,即 ScorePosNet3D 类。
四、主模型 ScorePosNet3D 代码解析
梯度引导模型 ScorePosNet3D (也是 TargetDiff 模型)来源于 models.molopt_score_model。
(其实这一部分的内容也适合其他的分子生成扩散模型,这些模型的相似度很高)
4.1 关于扩散模型的原理简介
扩散模型分为正向(添加噪音)和逆向(去噪音)过程,如下图。
在正向过程中,有0~T,共 T 个时刻。对于某一个时刻 t-1,都有一个 代表添加的正态分布噪音的权重, 数据
被添加噪音到
,这一过程的公式为:
将 1- 定义为
, 那么上述公式可写为:
通过迭代,以及整套分布噪音的性质,从 到
的过程可以使用如下公式直接计算得到:
其中, 为 从0~t 的
的累乘,即:
和
分别代表由
到
过程中,
和 正态分布噪音的系数。
去噪过程,也称之反向过程,即在 (真实或者神经网络预测的)和 t 时刻
的情况下,采样上一 t-1 时刻先验的
,其公式可以写为:
和
分别代表计算先验
时,
和
的系数,
则是噪音协方差的系数。前两项,可以称为
的均值,最后一项可以称为
的不确定性。在具体实践中,上述公式往往应用于坐标等连续变量。
但是在节点类型等分离变量时,使用贝叶斯推断,的计算公式如下:
在实际操作中,可以使用归一化来替代 q(v_t | v_0 ),即:
注:离散变量的目标是,通过有限的类别或状态进行转移,模型关注的是离散状态的变化。连续变量的目标是处理图像、信号等连续数据的生成和去噪,模型关注如何通过时间序列恢复出原始连续数据。
注:不管是坐标还是节点类型,我们都是输入 0 时刻和 t 时刻的,V,X,预测 和
, 我们计算的都是后验。扩散模型计算的损失是真实的后验(使用真实的
和
)和估计的后验(神经网络预测的
和
)分别计算的
和
之间的 KL 损失。节点类型和坐标在计算后验时,方法有所不同。
关于扩散模型的损失,在真实的扩散(向前)过程的逆过程 是已知的,这一过程是先验。我们会让神经网络
来预测
,即
, 然后,通过上述公式,计算
(这一过程为后验),即
,然后计算 后验的
和 先验的
的 KL 散度,做为神经网络的损失。因此,神经网络可以学习整个扩散/去噪过程中数据的分布。(注,实际上,在非分子生成领域,神经网络的损失是噪音的MSE,只有在最后几步 t(例如:1~5) 才是 KL 散度)
综上,为了训练一个扩散模型,有几个非常重要的参数,需要预先计算出来(往往在 __inti__ 函数中计算):
(1) 和
,分别代表在一步(由 t 到 t+1)添加噪音过程中,
和 噪音
的 ”系数“,以
为基础,
直接定义为一个列表;
(2) 和
分别代表由
到
过程中,
和 正态分布噪音
的系数,用于添加噪音的过程;
(3) 和
分别代表计算后验/先验的
时,
和
的系数,
则是噪音协方差的系数,用于计算损失;
4.2 主模型 ScorePosNet3D 代码解析
结合上述 3.1 过程中,对 ScorePosNet3D 其中的方法,逐个介绍。
作为一个分子生成扩散模型,具有两个重要的功能,一个是计算损失(训练)给优化器,对应的是 get_diffusion_loss() 方法;另一个是分子生成,对应的是 TargetDiff 自己的分子生成 sample_diffusion 函数。
此外,TagMol 也对 TargetDiff 的分子生成过程做了修改,以便实现梯度引导分子生成,在 ScorePosNet3D 还存在一个sample_guided_diffusion (或者 sample_multi_guided_diffusion ) 函数。
接下来,将按照初始化函数 __inti__、训练函数 get_diffusion_loss、分子生函数 sample_diffusion 的顺序对 ScorePosNet3D 进行全面介绍。梯度引导分子生成函数 sample_multi_guided_diffusion 已经在 2.6 部分进行了详细介绍,这里不重复。
4.2.1 __inti__ 函数
__inti__函数里面主要是预先设置了扩散模型中的参数,神经网络的超参数。
首先,(1) 基本设置:
def __init__(self, config, protein_atom_feature_dim, ligand_atom_feature_dim):
super().__init__()
'''
模型中神经网络的配置
扩散相关的超参数系数
'''
self.config = config
# variance schedule
# 神经网络输出的目标,noise 代表噪音, C0 代表真实值
self.model_mean_type = config.model_mean_type # ['noise', 'C0']
# 损失权重??
self.loss_v_weight = config.loss_v_weight
(2) 扩散模型的超参数设置:
首先是时间t 的列表
# 时间 t 的采样方法
self.sample_time_method = config.sample_time_method # ['importance', 'symmetric']
然后定义 β_t 和 α_t 列表:
# 创建 β_t 和 α_t 列表
if config.beta_schedule == 'cosine':
alphas = cosine_beta_schedule(config.num_diffusion_timesteps, config.pos_beta_s) ** 2
# print('cosine pos alpha schedule applied!')
betas = 1. - alphas
else:
betas = get_beta_schedule(
beta_schedule=config.beta_schedule,
beta_start=config.beta_start,
beta_end=config.beta_end,
num_diffusion_timesteps=config.num_diffusion_timesteps,
)
alphas = 1. - betas
其中, cosine_beta_schedule 定了一个 cosine 噪音调度器 ,也是很常用的噪音调度器,先算β_t, 然后计算α_t。代码如下:
def cosine_beta_schedule(timesteps, s=0.008):
"""
cosine schedule 噪音调度器
0~T 的 cos(1/t * 0.5 * Π),t 时刻的后累乘 除以 t前时刻的累乘, 然后再开根号
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
"""
steps = timesteps + 1
x = np.linspace(0, steps, steps)
alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
alphas = (alphas_cumprod[1:] / alphas_cumprod[:-1])
alphas = np.clip(alphas, a_min=0.001, a_max=1.)
# Use sqrt of this, so the alpha in our paper is the alpha_sqrt from the
# Gaussian diffusion in Ho et al.
alphas = np.sqrt(alphas)
return alphas
另外一个噪音调度器,先算α_t,然后计算 β_t ,代码如下:
def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
def sigmoid(x):
return 1 / (np.exp(-x) + 1)
if beta_schedule == "quad":
betas = (
np.linspace(
beta_start ** 0.5,
beta_end ** 0.5,
num_diffusion_timesteps,
dtype=np.float64,
)
** 2
)
elif beta_schedule == "linear":
betas = np.linspace(
beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
)
elif beta_schedule == "const":
betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
betas = 1.0 / np.linspace(
num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
)
elif beta_schedule == "sigmoid":
betas = np.linspace(-6, 6, num_diffusion_timesteps)
betas = sigmoid(betas) * (beta_end - beta_start) + beta_start
else:
raise NotImplementedError(beta_schedule)
assert betas.shape == (num_diffusion_timesteps,)
return betas
不同的噪音调度器,意味着不同的信息的衰减曲线。
计算添加噪音过程的 和
分别对应下述代码中的 self.sqrt_alphas_cumprod 和 self.sqrt_one_minus_alphas_cumprod,代表由
到
过程中,
和 正态分布噪音的系数,用于添加噪音的过程;
# α_t 的累乘
alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
# tensor 矩阵化
self.betas = to_torch_const(betas)
self.num_timesteps = self.betas.size(0)
self.alphas_cumprod = to_torch_const(alphas_cumprod)
self.alphas_cumprod_prev = to_torch_const(alphas_cumprod_prev)
# calculations for diffusion q(x_t | x_{t-1}) and others 添加噪音过程的超参数
# 添加噪音过程中 X_0 的比例
self.sqrt_alphas_cumprod = to_torch_const(np.sqrt(alphas_cumprod))
# 添加噪音过程中 噪音 的比例
self.sqrt_one_minus_alphas_cumprod = to_torch_const(np.sqrt(1. - alphas_cumprod))
self.sqrt_recip_alphas_cumprod = to_torch_const(np.sqrt(1. / alphas_cumprod))
self.sqrt_recipm1_alphas_cumprod = to_torch_const(np.sqrt(1. / alphas_cumprod - 1))
然后是计算先验/后验的,由 X_t 和 X_0 计算 X_t-1 的超参数, 和
,分别对应下面代码的 self.posterior_mean_c0_coef 和
self.posterior_mean_ct_coef, 代表 和
的系数;
则于代码中的 posterior_variance 对应,噪音协方差的系数。
# calculations for posterior q(x_{t-1} | x_t, x_0) 计算先验 X_t-1 的 超参数
# 给定当前状态 x_t 和 x_0 的情况下,前一个时间步 x_{t-1} 的方差。这些方差用于描述从 x_t 到 x_{t-1} 过程中引入的噪声量。
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
# X_0 的系数
self.posterior_mean_c0_coef = to_torch_const(betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
# X_t 的系数
self.posterior_mean_ct_coef = to_torch_const(
(1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))
# log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.posterior_var = to_torch_const(posterior_variance) # 后验方差
self.posterior_logvar = to_torch_const(np.log(np.append(self.posterior_var[1], self.posterior_var[1:]))) # 后验对数方差
对于节点特征而言,因为节点特征是 one-hot 的,为了数值稳定,需要将噪音添加在对数空间上操作,那么就需要计算对数空间的 和
, 以及噪音添加过程中的其他参数。节点特征的
和
分别对应下述代码中的 self.log_alphas_cumprod_v 和 self.log_one_minus_alphas_cumprod_v
# atom type diffusion schedule in log space,原子类型的扩散在对数空间实现,为了数值稳定
# 计算 α_t
if config.v_beta_schedule == 'cosine':
alphas_v = cosine_beta_schedule(self.num_timesteps, config.v_beta_s)
# print('cosine v alpha schedule applied!')
else:
raise NotImplementedError
# 计算 log(α_t)
log_alphas_v = np.log(alphas_v)
# 计算 log(α_t * α_t-1 .... α_0), 相当于α_t 的累积
log_alphas_cumprod_v = np.cumsum(log_alphas_v)
self.log_alphas_v = to_torch_const(log_alphas_v)
# 计算 log(1-α_t)
self.log_one_minus_alphas_v = to_torch_const(log_1_min_a(log_alphas_v))
self.log_alphas_cumprod_v = to_torch_const(log_alphas_cumprod_v)
# 计算 log((1-α_t) * (1-α_t-1) .... (1-α_0)), 相当于1-α_t 的累积
self.log_one_minus_alphas_cumprod_v = to_torch_const(log_1_min_a(log_alphas_cumprod_v))
(3)神经网络的超参数配置
小分子的输入节点向量维度,以及隐藏层向量维度,蛋白节点的输入向量维度
# model definition 神经网络模型定义
self.hidden_dim = config.hidden_dim # 隐藏层维度
self.num_classes = ligand_atom_feature_dim # 小分子的输入节点向量维度
if self.config.node_indicator: # ????, 参数是 True
emb_dim = self.hidden_dim - 1
else:
emb_dim = self.hidden_dim
# atom embedding 蛋白原子的嵌入层
self.protein_atom_emb = nn.Linear(protein_atom_feature_dim, emb_dim)
然后是,时间t 和小分子节点的嵌入层的设置。从代码中可以推测,时间t 的嵌入以后和小分子的嵌入合并在一起,二者形成了隐藏层的维度。
# time embedding 时间和小分子的原子嵌入
self.time_emb_dim = config.time_emb_dim
# 时间 t 的嵌入模式
self.time_emb_mode = config.time_emb_mode # ['simple', 'sin']
if self.time_emb_dim > 0:
# 简单嵌入
if self.time_emb_mode == 'simple':
self.ligand_atom_emb = nn.Linear(ligand_atom_feature_dim + 1, emb_dim)
# sin 嵌入
elif self.time_emb_mode == 'sin':
self.time_emb = nn.Sequential(
SinusoidalPosEmb(self.time_emb_dim),
nn.Linear(self.time_emb_dim, self.time_emb_dim * 4),
nn.GELU(),
nn.Linear(self.time_emb_dim * 4, self.time_emb_dim)
)
self.ligand_atom_emb = nn.Linear(ligand_atom_feature_dim + self.time_emb_dim, emb_dim)
else:
raise NotImplementedError
else:
self.ligand_atom_emb = nn.Linear(ligand_atom_feature_dim, emb_dim)
然后就是神经网络的初始化
self.refine_net_type = config.model_type # 神经网络,uni_o2,不是ENGG
self.refine_net = get_refine_net(self.refine_net_type, config)
# 节点特征网络
self.v_inference = nn.Sequential(
nn.Linear(self.hidden_dim, self.hidden_dim),
ShiftedSoftplus(),
nn.Linear(self.hidden_dim, ligand_atom_feature_dim),
)
get_refine_net 初始化神经网络,结合项目的配置文件,知道神网络阔使用的是 uni_o2,不是 ENGG。self.v_inference 用于从 uni_o2 的输出节点嵌入中,预测节点类型概率。
完整的__inti__代码如下:
def __init__(self, config, protein_atom_feature_dim, ligand_atom_feature_dim):
super().__init__()
'''
模型中神经网络的配置
扩散相关的超参数系数
'''
self.config = config
# variance schedule
# 神经网络输出的目标,noise 代表噪音, C0 代表真实值
self.model_mean_type = config.model_mean_type # ['noise', 'C0']
# 损失权重??
self.loss_v_weight = config.loss_v_weight
# self.v_mode = config.v_mode
# assert self.v_mode == 'categorical'
# self.v_net_type = getattr(config, 'v_net_type', 'mlp')
# self.bond_loss = getattr(config, 'bond_loss', False)
# self.bond_net_type = getattr(config, 'bond_net_type', 'pre_att')
# self.loss_bond_weight = getattr(config, 'loss_bond_weight', 0.)
# self.loss_non_bond_weight = getattr(config, 'loss_non_bond_weight', 0.)
# 时间 t 的采样方法
self.sample_time_method = config.sample_time_method # ['importance', 'symmetric']
# self.loss_pos_type = config.loss_pos_type # ['mse', 'kl']
# print(f'Loss pos mode {self.loss_pos_type} applied!')
# print(f'Loss bond net type: {self.bond_net_type} '
# f'bond weight: {self.loss_bond_weight} non bond weight: {self.loss_non_bond_weight}')
# 创建 β_t 和 α_t 列表
if config.beta_schedule == 'cosine':
alphas = cosine_beta_schedule(config.num_diffusion_timesteps, config.pos_beta_s) ** 2
# print('cosine pos alpha schedule applied!')
betas = 1. - alphas
else:
betas = get_beta_schedule(
beta_schedule=config.beta_schedule,
beta_start=config.beta_start,
beta_end=config.beta_end,
num_diffusion_timesteps=config.num_diffusion_timesteps,
)
alphas = 1. - betas
# α_t 的累乘
alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
# tensor 矩阵化
self.betas = to_torch_const(betas)
self.num_timesteps = self.betas.size(0)
self.alphas_cumprod = to_torch_const(alphas_cumprod)
self.alphas_cumprod_prev = to_torch_const(alphas_cumprod_prev)
# calculations for diffusion q(x_t | x_{t-1}) and others 添加噪音过程的超参数
# 添加噪音过程中 X_0 的比例
self.sqrt_alphas_cumprod = to_torch_const(np.sqrt(alphas_cumprod))
# 添加噪音过程中 噪音 的比例
self.sqrt_one_minus_alphas_cumprod = to_torch_const(np.sqrt(1. - alphas_cumprod))
self.sqrt_recip_alphas_cumprod = to_torch_const(np.sqrt(1. / alphas_cumprod))
self.sqrt_recipm1_alphas_cumprod = to_torch_const(np.sqrt(1. / alphas_cumprod - 1))
# calculations for posterior q(x_{t-1} | x_t, x_0) 计算先验 X_t-1 的 超参数
# 给定当前状态 x_t 和 x_0 的情况下,前一个时间步 x_{t-1} 的方差。这些方差用于描述从 x_t 到 x_{t-1} 过程中引入的噪声量。
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
# X_0 的系数
self.posterior_mean_c0_coef = to_torch_const(betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
# X_t 的系数
self.posterior_mean_ct_coef = to_torch_const(
(1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))
# log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.posterior_var = to_torch_const(posterior_variance) # 后验方差
self.posterior_logvar = to_torch_const(np.log(np.append(self.posterior_var[1], self.posterior_var[1:]))) # 后验对数方差
# atom type diffusion schedule in log space,原子类型的扩散在对数空间实现,为了数值稳定
# 计算 α_t
if config.v_beta_schedule == 'cosine':
alphas_v = cosine_beta_schedule(self.num_timesteps, config.v_beta_s)
# print('cosine v alpha schedule applied!')
else:
raise NotImplementedError
# 计算 log(α_t)
log_alphas_v = np.log(alphas_v)
# 计算 log(α_t * α_t-1 .... α_0), 相当于α_t 的累积
log_alphas_cumprod_v = np.cumsum(log_alphas_v)
self.log_alphas_v = to_torch_const(log_alphas_v)
# 计算 log(1-α_t)
self.log_one_minus_alphas_v = to_torch_const(log_1_min_a(log_alphas_v))
self.log_alphas_cumprod_v = to_torch_const(log_alphas_cumprod_v)
# 计算 log((1-α_t) * (1-α_t-1) .... (1-α_0)), 相当于1-α_t 的累积
self.log_one_minus_alphas_cumprod_v = to_torch_const(log_1_min_a(log_alphas_cumprod_v))
self.register_buffer('Lt_history', torch.zeros(self.num_timesteps))
self.register_buffer('Lt_count', torch.zeros(self.num_timesteps))
# model definition 神经网络模型定义
self.hidden_dim = config.hidden_dim # 隐藏层维度
self.num_classes = ligand_atom_feature_dim # 小分子的输入节点向量维度
if self.config.node_indicator: # ????, 参数是 True
emb_dim = self.hidden_dim - 1
else:
emb_dim = self.hidden_dim
# atom embedding 蛋白原子的嵌入层
self.protein_atom_emb = nn.Linear(protein_atom_feature_dim, emb_dim)
# center pos 中心模式
self.center_pos_mode = config.center_pos_mode # ['none', 'protein']
# time embedding 时间和小分子的原子嵌入
self.time_emb_dim = config.time_emb_dim
# 时间 t 的嵌入模式
self.time_emb_mode = config.time_emb_mode # ['simple', 'sin']
if self.time_emb_dim > 0:
# 简单嵌入
if self.time_emb_mode == 'simple':
self.ligand_atom_emb = nn.Linear(ligand_atom_feature_dim + 1, emb_dim)
# sin 嵌入
elif self.time_emb_mode == 'sin':
self.time_emb = nn.Sequential(
SinusoidalPosEmb(self.time_emb_dim),
nn.Linear(self.time_emb_dim, self.time_emb_dim * 4),
nn.GELU(),
nn.Linear(self.time_emb_dim * 4, self.time_emb_dim)
)
self.ligand_atom_emb = nn.Linear(ligand_atom_feature_dim + self.time_emb_dim, emb_dim)
else:
raise NotImplementedError
else:
self.ligand_atom_emb = nn.Linear(ligand_atom_feature_dim, emb_dim)
self.refine_net_type = config.model_type # 神经网络,uni_o2,不是ENGG
self.refine_net = get_refine_net(self.refine_net_type, config)
# 节点特征网络
self.v_inference = nn.Sequential(
nn.Linear(self.hidden_dim, self.hidden_dim),
ShiftedSoftplus(),
nn.Linear(self.hidden_dim, ligand_atom_feature_dim),
)
4.2.2 get_diffusion_loss 函数
get_diffusion_loss 函数输入蛋白坐标、蛋白节点类型、蛋白图 mask、小分子坐标、小分子节点类型、小分子图 mask 、以及时间步 t (默认为 None),计算扩散模型的损失。
(1) 计算批次中图的数量,去坐标中心
num_graphs = batch_protein.max().item() + 1 # 蛋白/小分子图数量
protein_pos, ligand_pos, _ = center_pos(
protein_pos, ligand_pos, batch_protein, batch_ligand, mode=self.center_pos_mode) # 坐标中心归0
(2) 采样时间 t,以及时间 t 出现的概率
if time_step is None:
# 采样
time_step, pt = self.sample_time(num_graphs, protein_pos.device, self.sample_time_method)
else:
# 按参数指定
pt = torch.ones_like(time_step).float() / self.num_timesteps
关于 self.sample_time 函数,按照权重或者对称随机采样与图数量一致的时间步 t ,代码及注释如下。
def sample_time(self, num_graphs, device, method):
if method == 'importance':
# 使用基于权重的重要性采样方法
if not (self.Lt_count > 10).all():
return self.sample_time(num_graphs, device, method='symmetric')
Lt_sqrt = torch.sqrt(self.Lt_history + 1e-10) + 0.0001
Lt_sqrt[0] = Lt_sqrt[1] # Overwrite decoder term with L1.
pt_a照分布概率多项式采样 时间步
time_step = torch.multinomial(pt_all, num_samples=num_graphs, replacement=True)
# 对应时间步的概率
pt = pt_all.gather(dim=0, index=time_step)
return time_step, pt
elif method == 'symmetric':
# 对称采样
# 随机抽取 图一半数量的 t
time_step = torch.randint(
0, self.num_timesteps, size=(num_graphs // 2 + 1,), device=device)
# 对称生成另一半数量的时间步 t
time_step = torch.cat(
[time_step, self.num_timesteps - time_step - 1], dim=0)[:num_graphs]
# 概率
pt = torch.ones_like(time_step).float() / self.num_timesteps
return time_step, ptll = Lt_sqrt / Lt_sqrt.sum() # 分布概率
# 按
else:
raise ValueError
(3) 提取时间步 t 对应的 :
a = self.alphas_cumprod.index_select(0, time_step) # (num_graphs, )
(4) 为小分子节点特征以及坐标添加噪音,分别生成添加噪音后的 ligand_v_perturbed, ligand_pos_perturbed。
# 扩展 α_cumprod,t 的维度与分子节点数相同
a_pos = a[batch_ligand].unsqueeze(-1) # (num_ligand_atoms, 1)
# 分子的噪音
pos_noise = torch.zeros_like(ligand_pos)
pos_noise.normal_()
# 扰动小分子的坐标,公式:Xt = a.sqrt() * X0 + (1-a).sqrt() * eps
ligand_pos_perturbed = a_pos.sqrt() * ligand_pos + (1.0 - a_pos).sqrt() * pos_noise # pos_noise * std
# 节点类型 one-hot 的 log 概率,然后添加噪音扰动, 公式:Vt = a * V0 + (1-a) / K ,
log_ligand_v0 = index_to_log_onehot(ligand_v, self.num_classes)
ligand_v_perturbed, log_ligand_vt = self.q_v_sample(log_ligand_v0, time_step, batch_ligand)
其中,坐标和节点特征添加噪音对应上文公式:
关于 index_to_log_onehot 函数,给定的节点特征转换为 one-hot 编码的形式,并返回其对数表示(概率),代码如下:
def index_to_log_onehot(x, num_classes):
'''
定的节点特征转换为 one-hot 编码的形式,并返回其对数表示(概率)
'''
assert x.max().item() < num_classes, f'Error: {x.max().item()} >= {num_classes}'
x_onehot = F.one_hot(x, num_classes)
# permute_order = (0, -1) + tuple(range(1, len(x.size())))
# x_onehot = x_onehot.permute(permute_order)
log_x = torch.log(x_onehot.float().clamp(min=1e-30))
return log_x
关于 q_v_sample 函数,在对数空间,添加噪音到节点特征, 代码如下:
def q_v_sample(self, log_v0, t, batch):
'''
在对数空间,添加噪音到节点特征
'''
# 添加噪音到对数概率
log_qvt_v0 = self.q_v_pred(log_v0, t, batch)
# 采样节点特征各类别
sample_index = log_sample_categorical(log_qvt_v0)
# 类别概率转化为 one-hot (去对数化)
log_sample = index_to_log_onehot(sample_index, self.num_classes)
return sample_index, log_sample
其中,self.q_v_pred 实现对数空间的添加节点特征噪音,代码如下:
def q_v_pred(self, log_v0, t, batch):
# compute q(vt | v0)
log_cumprod_alpha_t = extract(self.log_alphas_cumprod_v, t, batch)
log_1_min_cumprod_alpha = extract(self.log_one_minus_alphas_cumprod_v, t, batch)
# 节点特征噪音为 1/类别数
log_probs = log_add_exp(
log_v0 + log_cumprod_alpha_t,
log_1_min_cumprod_alpha - np.log(self.num_classes)
)
return log_probs
index_to_log_onehot 前面已经介绍过,不再重复。log_sample_categorical 代码如下:
def log_sample_categorical(logits):
'''
基于 Gumbel 分布 的离散分布采样方法。
它通过对每个类别的 logits 添加 Gumbel 噪声,
然后选择具有最大值的类别。这样的方法可以避免直接计算 softmax,
并且由于 Gumbel 分布的性质,这种采样方式是从 logits 所表示的分布中采样的
'''
uniform = torch.rand_like(logits)
gumbel_noise = -torch.log(-torch.log(uniform + 1e-30) + 1e-30)
sample_index = (gumbel_noise + logits).argmax(dim=-1)
# sample_onehot = F.one_hot(sample, self.num_classes)
# log_sample = index_to_log_onehot(sample, self.num_classes)
return sample_index
(5) 基于扰动的小分子坐标和小分子节点类型,以及真实的蛋白坐标和节点类型,调用 forward 函数预测节点类型和坐标,代码如下:
preds = self(
protein_pos=protein_pos,
protein_v=protein_v,
batch_protein=batch_protein,
init_ligand_pos=ligand_pos_perturbed,
init_ligand_v=ligand_v_perturbed,
batch_ligand=batch_ligand,
time_step=time_step
)
pred_ligand_pos, pred_ligand_v = preds['pred_ligand_pos'], preds['pred_ligand_v']
关于 forward 函数, 将时间步 t 与扰动后的小分子节点特征合并,然后小分子和蛋白的节点特征进行嵌入,随后,输入到神经网络 refine_net 中预测小分子的初始的坐标和节点特征。代码如下:
def forward(self, protein_pos, protein_v, batch_protein, init_ligand_pos, init_ligand_v, batch_ligand,
time_step=None, return_all=False, fix_x=False, guide=None):
batch_size = batch_protein.max().item() + 1 # 批次中图的数量
init_ligand_v = F.one_hot(init_ligand_v, self.num_classes).float() # 小分子节点类型转为 one-hot
## time embedding - currently not useful, generally passed as zero only
## if used, it is concatenated to the ligand features
# 时间步 t 进行嵌入,然后与小分子的节点类型进行concat
if self.time_emb_dim > 0:
if self.time_emb_mode == 'simple':
input_ligand_feat = torch.cat([
init_ligand_v,
(time_step / self.num_timesteps)[batch_ligand].unsqueeze(-1)
], -1)
elif self.time_emb_mode == 'sin':
time_feat = self.time_emb(time_step)
input_ligand_feat = torch.cat([init_ligand_v, time_feat], -1)
else:
raise NotImplementedError
else:
input_ligand_feat = init_ligand_v
## convert one-hot features into the embedding space
# 蛋白和小分子的节点类型进行嵌入
h_protein = self.protein_atom_emb(protein_v)
init_ligand_h = self.ligand_atom_emb(input_ligand_feat)
## add 0 to the end of hidden embedding of protein for every atom
## add 1 to the end of hidden embedding of protein for every atom
if self.config.node_indicator:
# 标记 哪个节点是属于蛋白,哪个节点属于小分子
h_protein = torch.cat([h_protein, torch.zeros(len(h_protein), 1).to(h_protein)], -1)
init_ligand_h = torch.cat([init_ligand_h, torch.ones(len(init_ligand_h), 1).to(h_protein)], -1)
## combines hidden states of protein and ligand into one set
## mask_ligand is used to keep track of which atom belongs to which type
## needed for forward pass of refine_net (uni_transformer)
# 合并蛋白和小分子的坐标、节点特征以及批次信息,按照批次信息进行重排顺序
h_all, pos_all, batch_all, mask_ligand = compose_context(
h_protein=h_protein,
h_ligand=init_ligand_h,
pos_protein=protein_pos,
pos_ligand=init_ligand_pos,
batch_protein=batch_protein,
batch_ligand=batch_ligand,
)
# 预测真实的 x 和 h
if guide is None:
# 没有引导的情况下,TargetDiff
outputs = self.refine_net(h_all, pos_all, mask_ligand, batch_all, return_all=return_all, fix_x=fix_x)
else:
# 在有梯度引导的情况下,TagMol
outputs = self.refine_net.forward_guided(
h_all, pos_all, mask_ligand, batch_all, return_all=return_all, fix_x=fix_x,
guide=guide
)
final_pos, final_h = outputs['x'], outputs['h']
## final positions of ligand only is needed (protein positions are not changed or used)
final_ligand_pos, final_ligand_h = final_pos[mask_ligand], final_h[mask_ligand]
## predict classes of atom-categories at the end of all the layers
# 进一步预测小分子的节点类型
final_ligand_v = self.v_inference(final_ligand_h)
# 输出内容
preds = {
'pred_ligand_pos': final_ligand_pos,
'pred_ligand_v': final_ligand_v,
'final_h': final_h,
'final_ligand_h': final_ligand_h
}
if return_all:
final_all_pos, final_all_h = outputs['all_x'], outputs['all_h']
final_all_ligand_pos = [pos[mask_ligand] for pos in final_all_pos]
final_all_ligand_v = [self.v_inference(h[mask_ligand]) for h in final_all_h]
preds.update({
'layer_pred_ligand_pos': final_all_ligand_pos,
'layer_pred_ligand_v': final_all_ligand_v
})
return preds
神经网络 refine_net 对应的是 uni_o2 模型,将在后面介绍。
(6) 提取小分子的坐标/噪音,然后计算坐标部分 x_0 的损失,仅计算 mse 损失,如下:
pred_ligand_pos, pred_ligand_v = preds['pred_ligand_pos'], preds['pred_ligand_v']
pred_pos_noise = pred_ligand_pos - ligand_pos_perturbed
# atom position
if self.model_mean_type == 'noise':
# 神经网络预测的是噪音
# 根据噪音,计算真实的坐标
pos0_from_e = self._predict_x0_from_eps(
xt=ligand_pos_perturbed, eps=pred_pos_noise, t=time_step, batch=batch_ligand)
# 计算后验 t-1 的小分子坐标 (后面没使用)
pos_model_mean = self.q_pos_posterior(
x0=pos0_from_e, xt=ligand_pos_perturbed, t=time_step, batch=batch_ligand)
elif self.model_mean_type == 'C0':
# 神经网络预测是的真实的小分子的坐标和节点类型
# 计算后验 t-1 的小分子坐标 (后面没使用)
pos_model_mean = self.q_pos_posterior(
x0=pred_ligand_pos, xt=ligand_pos_perturbed, t=time_step, batch=batch_ligand)
else:
raise ValueError
注,虽然,在其中使用 self.q_pos_posterior 函数输入 X_0 和 X_t 计算了坐标 x_t 的 前一步 x_t-1,但是,并没有计算 x_t-1 的先验真实值与后验预测值的 KL 散度作为损失,而是直接使用 x_0 的 mse 作为损失。self.q_pos_posterior 函数与上文提及的公式对应:
self.q_pos_posterior 函数的代码如下:
def q_pos_posterior(self, x0, xt, t, batch):
# 计算 X_t-1
# Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
pos_model_mean = extract(self.posterior_mean_c0_coef, t, batch) * x0 + \
extract(self.posterior_mean_ct_coef, t, batch) * xt
return pos_model_mean
(7) 计算先验的和后验的节点类型 v_t-1 的 KL 损失
# atom type loss 计算节点类型的 KL 损失
log_ligand_v_recon = F.log_softmax(pred_ligand_v, dim=-1) # 预测的小分子节点类型 one-hot
# 后验的 V_t-1 (预测的)
log_v_model_prob = self.q_v_posterior(log_ligand_v_recon, log_ligand_vt, time_step, batch_ligand)
# 先验的 V_t-1 (先验的)
log_v_true_prob = self.q_v_posterior(log_ligand_v0, log_ligand_vt, time_step, batch_ligand)
# 计算节点的类型 V_t-1 的 KL 损失
kl_v = self.compute_v_Lt(log_v_model_prob=log_v_model_prob, log_v0=log_ligand_v0,
log_v_true_prob=log_v_true_prob, t=time_step, batch=batch_ligand)
loss_v = torch.mean(kl_v)
其中,self.q_v_posterior 分别计算了先验的和后验的 V_t-1 的概率。需要注意的是,由于节点类型时离散变量,而使用了如下公式:
self.q_v_posterior 代码如下。(self.q_v_pred 上文已经介绍过,不再重复)
def q_v_posterior(self, log_v0, log_vt, t, batch):
# 对用的公式,q(vt-1 | vt, v0) = q(vt | vt-1, x0) * q(vt-1 | x0) / q(vt | x0)
# 注意这里没有使用 坐标类似的公式
t_minus_1 = t - 1
# Remove negative values, will not be used anyway for final decoder 去除负值
t_minus_1 = torch.where(t_minus_1 < 0, torch.zeros_like(t_minus_1), t_minus_1)
# q(vt-1 | x0)
log_qvt1_v0 = self.q_v_pred(log_v0, t_minus_1, batch)
# self.q_v_pred_one_timestep 计算 q(vt | vt-1, x0)
unnormed_logprobs = log_qvt1_v0 + self.q_v_pred_one_timestep(log_vt, t, batch)
# 归一化 V_t-1 的概率, 即 q(vt-1 | vt, v0)
log_vt1_given_vt_v0 = unnormed_logprobs - torch.logsumexp(unnormed_logprobs, dim=-1, keepdim=True)
return log_vt1_given_vt_v0
其中,self.q_v_pred 计算 q(vt-1 | x0), self.q_v_pred_one_timestep 实现一步的扩散 q(v_t | v_t-1) 。
但是似乎,self.q_v_pred_one_timestep(log_vt, t, batch) 有错误,应该改为 self.q_v_pred_one_timestep(log_qvt1_v0 , t, batch),但是不清楚这是不是真的错误,以及错误的影响。要试一下。
关于 self.compute_v_Lt, 计算真实的 V_t-1和预测的V_t-1 之间的 KL 散度(当t 不等于0 时)。当 t 等于0时,则计算真实的 V_t-1和预测的V_t-1 之间的 NLL,保证生成的分子与真实分子一致。代码如下:
def compute_v_Lt(self, log_v_model_prob, log_v0, log_v_true_prob, t, batch):
'''
KL 散度(KL divergence)或负对数似然(Negative Log Likelihood, NLL)
'''
kl_v = categorical_kl(log_v_true_prob, log_v_model_prob) # [num_atoms, ] # KL 散度
decoder_nll_v = -log_categorical(log_v0, log_v_model_prob) # L0 # 负对数似然
assert kl_v.shape == decoder_nll_v.shape
mask = (t == 0).float()[batch] # 如果 t=0 则不计算损失 KL 损失只计算 NLL
loss_v = scatter_mean(mask * decoder_nll_v + (1. - mask) * kl_v, batch, dim=0) # KL 散度或者 NLL
return loss_v
其中,categorical_kl 和 log_categoricald 的代码如下:
def log_categorical(log_x_start, log_prob):
return (log_x_start.exp() * log_prob).sum(dim=1)
def categorical_kl(log_prob1, log_prob2):
kl = (log_prob1.exp() * (log_prob1 - log_prob2)).sum(dim=1)
return kl
(8) 计算总损失,并返回预测的小分子坐标、节点类型,以及噪音和损失。代码如下:(注意,坐标的 mse 的权重为 100, 按照配置文件)
# 总损失为:坐标 X_0 的 mse(默认权重 100), 以及 V_t-1 的 KL 损失
loss = loss_pos + loss_v * self.loss_v_weight
return {
'loss_pos': loss_pos,
'loss_v': loss_v,
'loss': loss,
'x0': ligand_pos,
'pred_ligand_pos': pred_ligand_pos,
'pred_ligand_v': pred_ligand_v,
'pred_pos_noise': pred_pos_noise,
'ligand_v_recon': F.softmax(pred_ligand_v, dim=-1)
}
完整的 get_diffusion_loss 函数代码如下:
def get_diffusion_loss(
self, protein_pos, protein_v, batch_protein, ligand_pos, ligand_v, batch_ligand, time_step=None
):
num_graphs = batch_protein.max().item() + 1 # 蛋白/小分子图数量
protein_pos, ligand_pos, _ = center_pos(
protein_pos, ligand_pos, batch_protein, batch_ligand, mode=self.center_pos_mode) # 坐标中心归0
# 1. sample noise levels # 设置时间 t, 采样/按参数指定
if time_step is None:
# 采样
time_step, pt = self.sample_time(num_graphs, protein_pos.device, self.sample_time_method)
else:
# 按参数指定
pt = torch.ones_like(time_step).float() / self.num_timesteps
a = self.alphas_cumprod.index_select(0, time_step) # (num_graphs, )
# 2. perturb pos and v
# 扩展 α_cumprod,t 的维度与分子节点数相同
a_pos = a[batch_ligand].unsqueeze(-1) # (num_ligand_atoms, 1)
# 分子的噪音
pos_noise = torch.zeros_like(ligand_pos)
pos_noise.normal_()
# 扰动小分子的坐标,公式:Xt = a.sqrt() * X0 + (1-a).sqrt() * eps
ligand_pos_perturbed = a_pos.sqrt() * ligand_pos + (1.0 - a_pos).sqrt() * pos_noise # pos_noise * std
# 节点类型 one-hot 的 log 概率,然后添加噪音扰动, 公式:Vt = a * V0 + (1-a) / K ,
log_ligand_v0 = index_to_log_onehot(ligand_v, self.num_classes)
ligand_v_perturbed, log_ligand_vt = self.q_v_sample(log_ligand_v0, time_step, batch_ligand)
# 3. forward-pass NN, feed perturbed pos and v, output noise
# forward 函数预测真实的小分子节点类型和坐标
preds = self(
protein_pos=protein_pos,
protein_v=protein_v,
batch_protein=batch_protein,
init_ligand_pos=ligand_pos_perturbed,
init_ligand_v=ligand_v_perturbed,
batch_ligand=batch_ligand,
time_step=time_step
)
pred_ligand_pos, pred_ligand_v = preds['pred_ligand_pos'], preds['pred_ligand_v']
pred_pos_noise = pred_ligand_pos - ligand_pos_perturbed
# atom position
if self.model_mean_type == 'noise':
# 神经网络预测的是噪音
# 根据噪音,计算真实的坐标
pos0_from_e = self._predict_x0_from_eps(
xt=ligand_pos_perturbed, eps=pred_pos_noise, t=time_step, batch=batch_ligand)
# 计算后验 t-1 的小分子坐标 (后面没使用)
pos_model_mean = self.q_pos_posterior(
x0=pos0_from_e, xt=ligand_pos_perturbed, t=time_step, batch=batch_ligand)
elif self.model_mean_type == 'C0':
# 神经网络预测是的真实的小分子的坐标和节点类型
# 计算后验 t-1 的小分子坐标 (后面没使用)
pos_model_mean = self.q_pos_posterior(
x0=pred_ligand_pos, xt=ligand_pos_perturbed, t=time_step, batch=batch_ligand)
else:
raise ValueError
# atom pos loss 计算节点坐标的 MSE 损失
if self.model_mean_type == 'C0':
target, pred = ligand_pos, pred_ligand_pos
elif self.model_mean_type == 'noise':
target, pred = pos_noise, pred_pos_noise
else:
raise ValueError
loss_pos = scatter_mean(((pred - target) ** 2).sum(-1), batch_ligand, dim=0)
loss_pos = torch.mean(loss_pos)
# atom type loss 计算节点类型的 KL 损失
log_ligand_v_recon = F.log_softmax(pred_ligand_v, dim=-1) # 预测的小分子节点类型 one-hot
# 后验的 V_t-1 (预测的)
log_v_model_prob = self.q_v_posterior(log_ligand_v_recon, log_ligand_vt, time_step, batch_ligand)
# 先验的 V_t-1 (先验的)
log_v_true_prob = self.q_v_posterior(log_ligand_v0, log_ligand_vt, time_step, batch_ligand)
# 计算节点的类型 V_t-1 的 KL 损失
kl_v = self.compute_v_Lt(log_v_model_prob=log_v_model_prob, log_v0=log_ligand_v0,
log_v_true_prob=log_v_true_prob, t=time_step, batch=batch_ligand)
loss_v = torch.mean(kl_v)
# 总损失为:坐标 X_0 的 mse(默认权重 100), 以及 V_t-1 的 KL 损失
loss = loss_pos + loss_v * self.loss_v_weight
return {
'loss_pos': loss_pos,
'loss_v': loss_v,
'loss': loss,
'x0': ligand_pos,
'pred_ligand_pos': pred_ligand_pos,
'pred_ligand_v': pred_ligand_v,
'pred_pos_noise': pred_pos_noise,
'ligand_v_recon': F.softmax(pred_ligand_v, dim=-1)
}
总结一下 get_diffusion_loss 函数,调用 forward 函数预测真实的坐标和节点类型,然后分别计算了坐标的 mse 损失,以及节点类型的 q(v_t-1 | v_t, v_0) 的 真实的和基于预测值的 KL 损失。在权重上,坐标的 mse 损失的权重是 KL 损失的 100 倍。
4.2.3 sample_diffusion 函数
sample_diffusion 函数是 TargetDiff 的分子生成函数(注:不是 TagMol 的分子生生成函数),用于逐步对初始化的小分子坐标和节点(均含有噪音)进行去噪。
sample_diffusion 函数输入 初始化的小分子坐标 init_ligand_pos 和小分子节点特征 init_ligand_v 、蛋白坐标 protein_pos、蛋白节点 protein_v、小分子mask 和蛋白的 mask batch_protein, 输入小分子去噪过程中的节点特征 v0_pred_traj 和坐标 pos_traj、去噪过程中的神经网络预测的 v_0 pos0_traj 和 x_0 v0_pred_traj,以及最终生成的小分子坐标 ligand_pos 和节点特征 ligand_v。
这一部分 sample_diffusion 函数的代码与 2.6 部分 sample_multi_guided_diffusion 函数非常相似。只是少了梯度计算,并更新原子坐标的过程。
首先是基本设置,包括:去噪步数、批次中图数、坐标去中心、初始化去噪过程记录列表。
然后,就是在 for 循环里面,按照 T 的倒序,不断输入当前的 X_t 和 V_t 使用训练好的神经网络面预测 X_0 和 V_0,然后基于 X_0 和 V_0 以及当前的 X_t 和 V_t 使用去噪公式计算 X_t-1 和V_t-1,然后使用 X_t-1 和 V_t-1 替换 X_t 和 V_t。以此反复进行去噪,直到 T=0.
完整代码如下:
@torch.no_grad()
def sample_diffusion(self, protein_pos, protein_v, batch_protein,
init_ligand_pos, init_ligand_v, batch_ligand,
num_steps=None, center_pos_mode=None, pos_only=False):
if num_steps is None:
num_steps = self.num_timesteps
# 批次中图的数量
num_graphs = batch_protein.max().item() + 1
## Shifts the origin to the centre of mass of protein
## new protein positions, new ligand position and the difference from original position is in the offset
# 去坐标中心
protein_pos, init_ligand_pos, offset = center_pos(
protein_pos, init_ligand_pos, batch_protein, batch_ligand, mode=center_pos_mode)
pos_traj, v_traj = [], []
v0_pred_traj, vt_pred_traj, pos0_traj = [], [], []
ligand_pos, ligand_v = init_ligand_pos, init_ligand_v
## time sequence - going from 1000 to 1000 - num_steps
time_seq = list(reversed(range(self.num_timesteps - num_steps, self.num_timesteps)))
# 逐步去噪
for i in tqdm(time_seq, desc='sampling', total=len(time_seq)):
# 时间 t 扩散至图的数量
t = torch.full(size=(num_graphs,), fill_value=i, dtype=torch.long, device=protein_pos.device)
# 预测 v_0 和 x_0
preds = self(
protein_pos=protein_pos,
protein_v=protein_v,
batch_protein=batch_protein,
init_ligand_pos=ligand_pos,
init_ligand_v=ligand_v,
batch_ligand=batch_ligand,
time_step=t
)
# Compute posterior mean and variance 提取坐标预测值
if self.model_mean_type == 'noise':
pred_pos_noise = preds['pred_ligand_pos'] - ligand_pos
pos0_from_e = self._predict_x0_from_eps(xt=ligand_pos, eps=pred_pos_noise, t=t, batch=batch_ligand)
v0_from_e = preds['pred_ligand_v']
elif self.model_mean_type == 'C0':
pos0_from_e = preds['pred_ligand_pos']
v0_from_e = preds['pred_ligand_v']
else:
raise ValueError
# 计算坐标的 X_t-1
pos_model_mean = self.q_pos_posterior(x0=pos0_from_e, xt=ligand_pos, t=t, batch=batch_ligand)
# t 时刻的坐标方差
pos_log_variance = extract(self.posterior_logvar, t, batch_ligand)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float())[batch_ligand].unsqueeze(-1)
# 采样 X_t-1
ligand_pos_next = pos_model_mean + nonzero_mask * (0.5 * pos_log_variance).exp() * torch.randn_like(
ligand_pos)
ligand_pos = ligand_pos_next
if not pos_only:
# 预测的 V_0 (对数概率)
log_ligand_v_recon = F.log_softmax(v0_from_e, dim=-1)
# one-hot V_0
log_ligand_v = index_to_log_onehot(ligand_v, self.num_classes)
# V_t-1
log_model_prob = self.q_v_posterior(log_ligand_v_recon, log_ligand_v, t, batch_ligand)
# 采样 V_t-1
ligand_v_next = log_sample_categorical(log_model_prob)
v0_pred_traj.append(log_ligand_v_recon.clone().cpu())
vt_pred_traj.append(log_model_prob.clone().cpu())
ligand_v = ligand_v_next
# 记录采样轨迹
ori_ligand_pos0 = pos0_from_e + offset[batch_ligand] # 去噪过程中神经网络预测的 X_0
ori_ligand_pos = ligand_pos + offset[batch_ligand] # 去噪过程中采样的的 X_t
pos_traj.append(ori_ligand_pos.clone().cpu())
pos0_traj.append(ori_ligand_pos0.clone().cpu())
v_traj.append(ligand_v.clone().cpu()) # 去噪过程中采样的的 v_t
ligand_pos = ligand_pos + offset[batch_ligand]
return {
'pos': ligand_pos,
'v': ligand_v,
'pos_traj': pos_traj,
'pos0_traj': pos0_traj,
'v_traj': v_traj,
'v0_traj': v0_pred_traj,
'vt_traj': vt_pred_traj
}
至此, TargetDiff 模型(即,主模型 ScorePosNet3D )就已经全部介绍完成。
五、梯度引导模型 (DockGuideNet3D)
在 3.1 部分,训练梯度引导模型中,调用了 DockGuideNet3D 的 get_loss 方法计算梯度引导模型的损失。在分子生成的 2.3 梯度引导的分子分生成中,调用了 DockGuideNet3D 的 get_gradients_guide 方法来计算梯度。
DockGuideNet3D 在训练过程中,使用扰动后的小分子坐标和节点类型,以Binding Affinity 作为训练的 loss 目标。 小分子坐标和节点类型的扰动模式与主模型 ScorePosNet3D 完全相同。
下面,将按照__init__ 、get_loss、get_gradients_guide 对 DockGuideNet3D 进行全面的介绍。(注:在 GitHub 中,DockGuideNet3D 的代码很长,因为直接继承了很多 ScorePosNet3D 的代码修改而来,里面包含了很多没有使用到的函数。此外,由于在时间t采样等部分,与 ScorePosNet3D 完全相同,因此,这里对 DockGuideNet3D 的介绍会简单一些。)
5.1 __init__
(1)神经网络的预测模式、坐标 mse 损失权重、时间 t 采样的方式、
# 神经网络预测 噪音还是真实值,默认是 C0 真实值
self.model_mean_type = config.model_mean_type # ['noise', 'C0']
self.loss_v_weight = config.loss_v_weight
# 时间 t 的 采样,重要性采样 importance 或者 对称性采样 symmetric
self.sample_time_method = config.sample_time_method # ['importance', 'symmetric']
(2)坐标的噪音调度器 β及其参数(α,α累乘,根号下1-α,计算 x_t-1 的系数等)
# 噪音调度器 β 每一步添加噪音的比例; α (即,1-β) 每一步保留原信息的比例
if config.beta_schedule == 'cosine':
alphas = cosine_beta_schedule(config.num_diffusion_timesteps, config.pos_beta_s) ** 2
# print('cosine pos alpha schedule applied!')
betas = 1. - alphas
else:
betas = get_beta_schedule(
beta_schedule=config.beta_schedule,
beta_start=config.beta_start,
beta_end=config.beta_end,
num_diffusion_timesteps=config.num_diffusion_timesteps,
)
alphas = 1. - betas
alphas_cumprod = np.cumprod(alphas, axis=0) # α的累乘
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) # α的累乘,从t=0开始
self.betas = to_torch_const(betas)
self.num_timesteps = self.betas.size(0)
self.alphas_cumprod = to_torch_const(alphas_cumprod) # α的累乘
self.alphas_cumprod_prev = to_torch_const(alphas_cumprod_prev) # α的累乘,从t=0开始
# 从 x_0 到 x_t 的过程中, x_0 的比例
self.sqrt_alphas_cumprod = to_torch_const(np.sqrt(alphas_cumprod)) # 根号下 α的累乘,从t=0开始
# 从 x_0 到 x_t 的过程中, 噪音的比例
self.sqrt_one_minus_alphas_cumprod = to_torch_const(np.sqrt(1. - alphas_cumprod)) # 根号下(1-α的累乘)
self.sqrt_recip_alphas_cumprod = to_torch_const(np.sqrt(1. / alphas_cumprod))
self.sqrt_recipm1_alphas_cumprod = to_torch_const(np.sqrt(1. / alphas_cumprod - 1))
# 每个时刻 t 下的 x_t 的概率(不确定性/方差)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
# 计算 x_{t-1} 的 x_0 的系数
self.posterior_mean_c0_coef = to_torch_const(betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
# 计算 x_{t-1} 的 x_t 的系数
self.posterior_mean_ct_coef = to_torch_const(
(1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))
self.posterior_var = to_torch_const(posterior_variance) # 每个时刻 t 下的 x_t 的概率(不确定性/方差)
self.posterior_logvar = to_torch_const(np.log(np.append(self.posterior_var[1], self.posterior_var[1:]))) # 对数方差
(3) 节点特征相关的噪音调度器
# # 节点类型的噪音调度器
# 计算 α_t
if config.v_beta_schedule == 'cosine':
alphas_v = cosine_beta_schedule(self.num_timesteps, config.v_beta_s)
# print('cosine v alpha schedule applied!')
else:
raise NotImplementedError
# 计算 log(α_t)
log_alphas_v = np.log(alphas_v)
# 计算 log(α_t * α_t-1 .... α_0), 相当于α_t 的累积
log_alphas_cumprod_v = np.cumsum(log_alphas_v)
self.log_alphas_v = to_torch_const(log_alphas_v)
# 计算 log(1-α_t)
self.log_one_minus_alphas_v = to_torch_const(log_1_min_a(log_alphas_v))
# 计算 log((1-α_t) * (1-α_t-1) .... (1-α_0)), 相当于1-α_t 的累积
self.log_alphas_cumprod_v = to_torch_const(log_alphas_cumprod_v)
self.log_one_minus_alphas_cumprod_v = to_torch_const(log_1_min_a(log_alphas_cumprod_v))
# 设置β时的权重。只有在 config.sample_time_method == inprotance 时成立
self.register_buffer('Lt_history', torch.zeros(self.num_timesteps))
self.register_buffer('Lt_count', torch.zeros(self.num_timesteps))
注:坐标和节点相关的噪音调度器中,有大量的参数是 DockGuideNet3D 梯度引导模型不会使用的,但是在 GitHub 中仍然保留了,这些属于多余的代码。
(4) 神经网络的配置。包括:隐藏层的维度、小分子的节点类型种类数、是否包含小分子/蛋白的mask、蛋白的嵌入层、小分子和时间步 t 的合并嵌入、get_refine_net 函数获取并实例化神经网络、推理 binding affinity 层、辅助断言(在开发模型测试使用,此处可忽略)。
# model definition 神经网络的定义
self.hidden_dim = config.hidden_dim # 隐藏层维度
self.num_classes = ligand_atom_feature_dim # 小分子节点类型数
if self.config.node_indicator:
# 如果包含小分子/蛋白的标记mask
emb_dim = self.hidden_dim - 1
else:
emb_dim = self.hidden_dim
# atom embedding 原子 embeding 层
self.protein_atom_emb = nn.Linear(protein_atom_feature_dim, emb_dim)
# center pos 坐标中心
self.center_pos_mode = config.center_pos_mode # ['none', 'protein']
# time embedding 时间嵌入层,与 小分子节点特征 embeding 层
self.time_emb_dim = config.time_emb_dim
self.time_emb_mode = config.time_emb_mode # ['simple', 'sin']
if self.time_emb_dim > 0:
if self.time_emb_mode == 'simple':
self.ligand_atom_emb = nn.Linear(ligand_atom_feature_dim + 1, emb_dim)
elif self.time_emb_mode == 'sin':
self.time_emb = nn.Sequential(
SinusoidalPosEmb(self.time_emb_dim),
nn.Linear(self.time_emb_dim, self.time_emb_dim * 4),
nn.GELU(),
nn.Linear(self.time_emb_dim * 4, self.time_emb_dim)
)
self.ligand_atom_emb = nn.Linear(ligand_atom_feature_dim + self.time_emb_dim, emb_dim)
else:
raise NotImplementedError
else:
self.ligand_atom_emb = nn.Linear(ligand_atom_feature_dim, emb_dim)
# 实例化神经网络
self.refine_net_type = config.model_type
self.refine_net = get_refine_net(self.refine_net_type, config)
# 推理 binding affinity 层
self.dock_inference = nn.Sequential(
nn.Linear(self.hidden_dim, self.hidden_dim),
ShiftedSoftplus(),
nn.Linear(self.hidden_dim, 1),
)
# TODO: show warning if drop_protein_in_guide not in config
self.drop_protein_in_guide = config.get("drop_protein_in_guide", False)
assert isinstance(self.drop_protein_in_guide, bool), f"`drop_protein_in_guide` should be True or False. Got: {self.drop_protein_in_guide}"
# TODO: show warning if maximize_property not in config
self.maximize_property = config.get("maximize_property", False)
assert isinstance(self.maximize_property, bool), f"`maximize_property` should be True or False. Got: {self.maximize_property}"
# 推理 binding affinity 层的输出类型
self.problem_type = config.get("problem_type", "regression")
assert self.problem_type in ["regression", "classification"], f"`problem_type` should be 'regression' or 'classification'. Got: {self.problem_type}"
关于 get_refine_net 函数,实例化一个神经网路。这个神经网络名为 egnn, 按照配置文件的定义,包含了9层,维度为128,距离的高斯维度为20的 EGNN 网络。代码如下:
def get_refine_net(refine_net_type, config):
if refine_net_type == 'uni_o2':
raise NotImplementedError(refine_net_type)
elif refine_net_type == 'egnn':
refine_net = EGNN(
num_layers=config.num_layers,
hidden_dim=config.hidden_dim,
edge_feat_dim=config.edge_feat_dim,
num_r_gaussian=config.num_r_gaussian,
k=config.knn,
cutoff=config.r_max,
update_x=config.update_x
)
else:
raise ValueError(refine_net_type)
return refine_net
需要特别注意的是,DockGuideNet3D 的坐标和节点类型的噪音调度器与主模型 ScorePosNet3D 完全一致。完整的 __init__ 代码如下:
def __init__(self, config, protein_atom_feature_dim, ligand_atom_feature_dim):
super().__init__()
self.config = config
# variance schedule
# 神经网络预测 噪音还是真实值,默认是 C0 真实值
self.model_mean_type = config.model_mean_type # ['noise', 'C0']
self.loss_v_weight = config.loss_v_weight
# 时间 t 的 采样,重要性采样 importance 或者 对称性采样 symmetric
self.sample_time_method = config.sample_time_method # ['importance', 'symmetric']
# 噪音调度器 β 每一步添加噪音的比例; α (即,1-β) 每一步保留原信息的比例
if config.beta_schedule == 'cosine':
alphas = cosine_beta_schedule(config.num_diffusion_timesteps, config.pos_beta_s) ** 2
# print('cosine pos alpha schedule applied!')
betas = 1. - alphas
else:
betas = get_beta_schedule(
beta_schedule=config.beta_schedule,
beta_start=config.beta_start,
beta_end=config.beta_end,
num_diffusion_timesteps=config.num_diffusion_timesteps,
)
alphas = 1. - betas
alphas_cumprod = np.cumprod(alphas, axis=0) # α的累乘
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) # α的累乘,从t=0开始
self.betas = to_torch_const(betas)
self.num_timesteps = self.betas.size(0)
self.alphas_cumprod = to_torch_const(alphas_cumprod) # α的累乘
self.alphas_cumprod_prev = to_torch_const(alphas_cumprod_prev) # α的累乘,从t=0开始
# calculations for diffusion q(x_t | x_{t-1}) and others,添加噪音过程
# 从 x_0 到 x_t 的过程中, x_0 的比例
self.sqrt_alphas_cumprod = to_torch_const(np.sqrt(alphas_cumprod)) # 根号下 α的累乘,从t=0开始
# 从 x_0 到 x_t 的过程中, 噪音的比例
self.sqrt_one_minus_alphas_cumprod = to_torch_const(np.sqrt(1. - alphas_cumprod)) # 根号下(1-α的累乘)
self.sqrt_recip_alphas_cumprod = to_torch_const(np.sqrt(1. / alphas_cumprod))
self.sqrt_recipm1_alphas_cumprod = to_torch_const(np.sqrt(1. / alphas_cumprod - 1))
# calculations for posterior q(x_{t-1} | x_t, x_0)
# 每个时刻 t 下的 x_t 的概率(不确定性/方差)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
# 计算 x_{t-1} 的 x_0 的系数
self.posterior_mean_c0_coef = to_torch_const(betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
# 计算 x_{t-1} 的 x_t 的系数
self.posterior_mean_ct_coef = to_torch_const(
(1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))
# log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.posterior_var = to_torch_const(posterior_variance) # 每个时刻 t 下的 x_t 的概率(不确定性/方差)
self.posterior_logvar = to_torch_const(np.log(np.append(self.posterior_var[1], self.posterior_var[1:]))) # 对数方差
# atom type diffusion schedule in log space
# # 节点类型的噪音调度器
# 计算 α_t
if config.v_beta_schedule == 'cosine':
alphas_v = cosine_beta_schedule(self.num_timesteps, config.v_beta_s)
# print('cosine v alpha schedule applied!')
else:
raise NotImplementedError
# 计算 log(α_t)
log_alphas_v = np.log(alphas_v)
# 计算 log(α_t * α_t-1 .... α_0), 相当于α_t 的累积
log_alphas_cumprod_v = np.cumsum(log_alphas_v)
self.log_alphas_v = to_torch_const(log_alphas_v)
# 计算 log(1-α_t)
self.log_one_minus_alphas_v = to_torch_const(log_1_min_a(log_alphas_v))
# 计算 log((1-α_t) * (1-α_t-1) .... (1-α_0)), 相当于1-α_t 的累积
self.log_alphas_cumprod_v = to_torch_const(log_alphas_cumprod_v)
self.log_one_minus_alphas_cumprod_v = to_torch_const(log_1_min_a(log_alphas_cumprod_v))
# 设置β时的权重。只有在 config.sample_time_method == inprotance 时成立
self.register_buffer('Lt_history', torch.zeros(self.num_timesteps))
self.register_buffer('Lt_count', torch.zeros(self.num_timesteps))
# model definition 神经网络的定义
self.hidden_dim = config.hidden_dim # 隐藏层维度
self.num_classes = ligand_atom_feature_dim # 小分子节点类型数
if self.config.node_indicator:
# 如果包含小分子/蛋白的标记mask
emb_dim = self.hidden_dim - 1
else:
emb_dim = self.hidden_dim
# atom embedding 原子 embeding 层
self.protein_atom_emb = nn.Linear(protein_atom_feature_dim, emb_dim)
# center pos 坐标中心
self.center_pos_mode = config.center_pos_mode # ['none', 'protein']
# time embedding 时间嵌入层,与 小分子节点特征 embeding 层
self.time_emb_dim = config.time_emb_dim
self.time_emb_mode = config.time_emb_mode # ['simple', 'sin']
if self.time_emb_dim > 0:
if self.time_emb_mode == 'simple':
self.ligand_atom_emb = nn.Linear(ligand_atom_feature_dim + 1, emb_dim)
elif self.time_emb_mode == 'sin':
self.time_emb = nn.Sequential(
SinusoidalPosEmb(self.time_emb_dim),
nn.Linear(self.time_emb_dim, self.time_emb_dim * 4),
nn.GELU(),
nn.Linear(self.time_emb_dim * 4, self.time_emb_dim)
)
self.ligand_atom_emb = nn.Linear(ligand_atom_feature_dim + self.time_emb_dim, emb_dim)
else:
raise NotImplementedError
else:
self.ligand_atom_emb = nn.Linear(ligand_atom_feature_dim, emb_dim)
# 实例化神经网络
self.refine_net_type = config.model_type
self.refine_net = get_refine_net(self.refine_net_type, config)
# 推理 binding affinity 层
self.dock_inference = nn.Sequential(
nn.Linear(self.hidden_dim, self.hidden_dim),
ShiftedSoftplus(),
nn.Linear(self.hidden_dim, 1),
)
# TODO: show warning if drop_protein_in_guide not in config
self.drop_protein_in_guide = config.get("drop_protein_in_guide", False)
assert isinstance(self.drop_protein_in_guide, bool), f"`drop_protein_in_guide` should be True or False. Got: {self.drop_protein_in_guide}"
# TODO: show warning if maximize_property not in config
self.maximize_property = config.get("maximize_property", False)
assert isinstance(self.maximize_property, bool), f"`maximize_property` should be True or False. Got: {self.maximize_property}"
# 推理 binding affinity 层的输出类型
self.problem_type = config.get("problem_type", "regression")
assert self.problem_type in ["regression", "classification"], f"`problem_type` should be 'regression' or 'classification'. Got: {self.problem_type}"
5.2 get_loss
训练梯度引导模型中,DockGuideNet3D 的 get_loss 方法计算梯度引导模型的损失。与主模型的 get_diffusion_loss 很像。
(1)先是基本设置,
# 批次中图数量
num_graphs = batch_protein.max().item() + 1
# 去坐标中心
protein_pos, ligand_pos, _ = center_pos(
protein_pos, ligand_pos, batch_protein, batch_ligand, mode=self.center_pos_mode)
# 1. sample noise levels
if time_step is None:
# 采样时间 t 及其概率
time_step, pt = self.sample_time(num_graphs, protein_pos.device, self.sample_time_method)
else:
# 时间 t 固定参数输入,只计算概率
pt = torch.ones_like(time_step).float() / self.num_timesteps
## precomputed beforehand to save computational time. Here it is only indexed based on the time_step
a = self.alphas_cumprod.index_select(0, time_step) # (num_graphs, ) # α 累乘
(2)扰动小分子的坐标和节点类型;
# 2. perturb pos and v
a_pos = a[batch_ligand].unsqueeze(-1) # (num_ligand_atoms, 1) # α累乘 扩展
## sampling a normal distribution to add as noise
# 初始化/采样噪音
pos_noise = torch.zeros_like(ligand_pos)
pos_noise.normal_()
## update the coordinates
# Xt = a.sqrt() * X0 + (1-a).sqrt() * eps 扰动小分子的坐标
ligand_pos_perturbed = a_pos.sqrt() * ligand_pos + (1.0 - a_pos).sqrt() * pos_noise # pos_noise * std
## update the categories 扰动原子类型
# Vt = a * V0 + (1-a) / K
log_ligand_v0 = index_to_log_onehot(ligand_v, self.num_classes) # 转为 onr-hot 类型,然后取对数,即,生成节点类型的对数概率
## move some of the probablity mass to other indexes
# 采样扰动后的节点类型索引,及其对数概率
ligand_v_perturbed, log_ligand_vt = self.q_v_sample(log_ligand_v0, time_step, batch_ligand)
其中,index_to_log_onehot 和 self.q_v_sample 等一些支持函数与主模型的完全相同,这里就不再重复,把代码直接贴出来。节点类型和坐标在被扰动的时候,都遵循上述提及的公式:
index_to_log_onehot 和 self.q_v_sample 等支持函数的代码如下:
def index_to_log_onehot(x, num_classes):
'''
索引转化为对数概率
'''
assert x.max().item() < num_classes, f'Error: {x.max().item()} >= {num_classes}'
x_onehot = F.one_hot(x, num_classes)
# permute_order = (0, -1) + tuple(range(1, len(x.size())))
# x_onehot = x_onehot.permute(permute_order)
log_x = torch.log(x_onehot.float().clamp(min=1e-30))
return log_x
def q_v_sample(self, log_v0, t, batch):
'''
基于概率,采样节点类型
返回 节点类型的索引及其对数概率
'''
# 扰动,生成 v_t (对数概率)
log_qvt_v0 = self.q_v_pred(log_v0, t, batch)
# 基于概率,随机采样节点类型概率,返回节点类型索引
sample_index = log_sample_categorical(log_qvt_v0)
# 将节点类型转化为 对数概率, 即:log(one-hot)
log_sample = index_to_log_onehot(sample_index, self.num_classes)
return sample_index, log_sample
def q_v_pred(self, log_v0, t, batch):
'''
计算 q(V_t | V_0),扰动节点类型
'''
# compute q(vt | v0)
log_cumprod_alpha_t = extract(self.log_alphas_cumprod_v, t, batch)
log_1_min_cumprod_alpha = extract(self.log_one_minus_alphas_cumprod_v, t, batch)
log_probs = log_add_exp(
log_v0 + log_cumprod_alpha_t,
log_1_min_cumprod_alpha - np.log(self.num_classes)
)
return log_probs
def log_sample_categorical(logits):
uniform = torch.rand_like(logits) # 按照对数概率,采样节点类型的’随机性‘
gumbel_noise = -torch.log(-torch.log(uniform + 1e-30) + 1e-30) # 随机性进行对数化
# 随机性与概率合并,返回最大概率的节点类型(不是 one-hot)
sample_index = (gumbel_noise + logits).argmax(dim=-1)
# sample_onehot = F.one_hot(sample, self.num_classes)
# log_sample = index_to_log_onehot(sample, self.num_classes)
return sample_index
(3)输入到 forword 函数中计算预测的 Binding Affinity
# 3. forward-pass NN, feed perturbed pos and v, output noise
# 将 蛋白坐标和节点类型,扰动后的小分子坐标和节点类型,以及时间步 t 输入到 forward 中
# 输出预测的 binding affinity
preds = self(
protein_pos=protein_pos,
protein_atom_feature=protein_v,
batch_protein=batch_protein,
ligand_pos=ligand_pos_perturbed,
ligand_atom_feature=F.one_hot(ligand_v_perturbed, self.num_classes), # one-hot
batch_ligand=batch_ligand,
time_step=time_step,
fix_x=True
)
关于 forward 函数,主要包括:
(3.1) 时间步 t 的嵌入,然后与小分子节点类型 contact 合并成新的小分子节点类型;
if self.time_emb_dim > 0:
if self.time_emb_mode == 'simple':
input_ligand_feat = torch.cat([
ligand_atom_feature,
(time_step / self.num_timesteps)[batch_ligand].unsqueeze(-1)
], -1)
elif self.time_emb_mode == 'sin':
time_feat = self.time_emb(time_step)
time_feat = time_feat[batch_ligand]
input_ligand_feat = torch.cat([ligand_atom_feature, time_feat], -1)
else:
raise NotImplementedError
else:
input_ligand_feat = ligand_atom_feature
(3.2) 小分子节点特征的标记与嵌入
# 新的小分子节点类型 嵌入
init_ligand_h = self.ligand_atom_emb(input_ligand_feat)
# 小分子节点的标记
if self.config.node_indicator:
init_ligand_h = torch.cat([init_ligand_h, torch.ones(len(init_ligand_h), 1).to(init_ligand_h.device)], -1)
(3.3) 蛋白节点的标记合并蛋白和小分子的坐标以及节点类型
if self.drop_protein_in_guide is True:
# 如果 在梯度计算中不包含 蛋白,比如:预测logP
h_all, pos_all, batch_all = init_ligand_h, ligand_pos, batch_ligand
# TODO: check `mask_ligand`
mask_ligand = torch.ones([batch_ligand.size(0)], device=batch_ligand.device).bool()
else:
# 在预测中需要包含蛋白
# 蛋白节点嵌入
h_protein = self.protein_atom_emb(protein_atom_feature)
# 蛋白节点标记
if self.config.node_indicator:
h_protein = torch.cat([h_protein, torch.zeros(len(h_protein), 1).to(h_protein.device)], -1)
# 合并蛋白和小分子节点类型和坐标
h_all, pos_all, batch_all, mask_ligand = compose_context(
h_protein=h_protein,
h_ligand=init_ligand_h,
pos_protein=protein_pos,
pos_ligand=ligand_pos,
batch_protein=batch_protein,
batch_ligand=batch_ligand,
)
关于 compose_context 函数,也与主函数完全一致,如下:
def compose_context(h_protein, h_ligand, pos_protein, pos_ligand, batch_protein, batch_ligand):
'''
蛋白质 (protein) 和配体 (ligand) 的嵌入信息、位置和批次信息进行组合,并返回合并后的上下文表示
'''
# previous version has problems when ligand atom types are fixed
# (due to sorting randomly in case of same element)
batch_ctx = torch.cat([batch_protein, batch_ligand], dim=0) # mask 批次信息 图的序号
# sort_idx = batch_ctx.argsort() 按照批次信息进行排序的索引
sort_idx = torch.sort(batch_ctx, stable=True).indices
# mask 用于区分蛋白和小分子的掩码,按照图的序号重拍后
mask_ligand = torch.cat([
torch.zeros([batch_protein.size(0)], device=batch_protein.device).bool(),
torch.ones([batch_ligand.size(0)], device=batch_ligand.device).bool(),
], dim=0)[sort_idx]
batch_ctx = batch_ctx[sort_idx] # 重拍后的批次信息
# 重拍后的节点特征
h_ctx = torch.cat([h_protein, h_ligand], dim=0)[sort_idx] # (N_protein+N_ligand, H)
# 重拍后的坐标
pos_ctx = torch.cat([pos_protein, pos_ligand], dim=0)[sort_idx] # (N_protein+N_ligand, 3)
return h_ctx, pos_ctx, batch_ctx, mask_ligand
(3.4) 使用 EGNN 以及推理 binding affinity 层,预测并返回 Binding Affinity 的值
## get the hidden states from GNN EGNN 计算预测值(维度是:(atom_number, ))
outputs = self.refine_net(h_all, pos_all, mask_ligand, batch_all, return_all=return_all, fix_x=fix_x)
if fix_x:
assert torch.all(torch.eq(outputs['x'], pos_all))
## aggregate/pool the hidden states 整合输出 (graph_num,)
aggregate_output = scatter(outputs['h'], index=batch_all, dim=0, reduce='sum')
## get the binding energy (graph_num,)
# 推理 binding affinity 层
output = self.dock_inference(aggregate_output)
return output
完整的 forward hanshu 如下:
def forward(self, protein_pos, protein_atom_feature, ligand_pos, ligand_atom_feature, batch_protein, batch_ligand,
time_step, return_all=False, fix_x=False):
'''
基于 蛋白的坐标和节点类型、扰动后的小分子坐标和节点类型、时间步t,
以及批次mask(batch_protein, batch_ligand),
输出 预测的 binding affinity(也可以是其他)
'''
batch_size = batch_protein.max().item() + 1
# time embedding
## added to inform the model about the extent of noise that has been added
# 时间步 t 的嵌入,然后与小分子节点类型 contact 合并成新的小分子节点类型
if self.time_emb_dim > 0:
if self.time_emb_mode == 'simple':
input_ligand_feat = torch.cat([
ligand_atom_feature,
(time_step / self.num_timesteps)[batch_ligand].unsqueeze(-1)
], -1)
elif self.time_emb_mode == 'sin':
time_feat = self.time_emb(time_step)
time_feat = time_feat[batch_ligand]
input_ligand_feat = torch.cat([ligand_atom_feature, time_feat], -1)
else:
raise NotImplementedError
else:
input_ligand_feat = ligand_atom_feature
# 新的小分子节点类型 嵌入
init_ligand_h = self.ligand_atom_emb(input_ligand_feat)
# 小分子节点的标记
if self.config.node_indicator:
init_ligand_h = torch.cat([init_ligand_h, torch.ones(len(init_ligand_h), 1).to(init_ligand_h.device)], -1)
if self.drop_protein_in_guide is True:
# 如果 在梯度计算中不包含 蛋白,比如:预测logP
h_all, pos_all, batch_all = init_ligand_h, ligand_pos, batch_ligand
# TODO: check `mask_ligand`
mask_ligand = torch.ones([batch_ligand.size(0)], device=batch_ligand.device).bool()
else:
# 在预测中需要包含蛋白
# 蛋白节点嵌入
h_protein = self.protein_atom_emb(protein_atom_feature)
# 蛋白节点标记
if self.config.node_indicator:
h_protein = torch.cat([h_protein, torch.zeros(len(h_protein), 1).to(h_protein.device)], -1)
# 合并蛋白和小分子节点类型和坐标
h_all, pos_all, batch_all, mask_ligand = compose_context(
h_protein=h_protein,
h_ligand=init_ligand_h,
pos_protein=protein_pos,
pos_ligand=ligand_pos,
batch_protein=batch_protein,
batch_ligand=batch_ligand,
)
## get the hidden states from GNN EGNN 计算预测值(维度是:(atom_number, ))
outputs = self.refine_net(h_all, pos_all, mask_ligand, batch_all, return_all=return_all, fix_x=fix_x)
if fix_x:
assert torch.all(torch.eq(outputs['x'], pos_all))
## aggregate/pool the hidden states 整合输出 (graph_num,)
aggregate_output = scatter(outputs['h'], index=batch_all, dim=0, reduce='sum')
## get the binding energy (graph_num,)
# 推理 binding affinity 层
output = self.dock_inference(aggregate_output)
return output
(4)随后计算预测的 Binding Affinity 与真实值 dock 之间的损失;
if self.problem_type == "regression":
# 如果梯度引导模型是 回归模式
loss_func = nn.MSELoss()
loss = loss_func(preds.view(-1), dock)
elif self.problem_type == "classification":
# 分类模式
loss_func = nn.BCEWithLogitsLoss()
loss = loss_func(preds.view(-1), dock.float())
else:
raise ValueError(f"Unknown problem type: {self.problem_type}")
(5)返回损失及其预测值。
# 返回损失和预测值
if return_pred:
return loss, preds
else:
return loss
get_loss 的完整代码如下:
def get_loss(
self, protein_pos, protein_v, batch_protein, ligand_pos, ligand_v, batch_ligand, dock, time_step=None, return_pred=False
):
# 批次中图数量
num_graphs = batch_protein.max().item() + 1
# 去坐标中心
protein_pos, ligand_pos, _ = center_pos(
protein_pos, ligand_pos, batch_protein, batch_ligand, mode=self.center_pos_mode)
# 1. sample noise levels
if time_step is None:
# 采样时间 t 及其概率
time_step, pt = self.sample_time(num_graphs, protein_pos.device, self.sample_time_method)
else:
# 时间 t 固定参数输入,只计算概率
pt = torch.ones_like(time_step).float() / self.num_timesteps
## precomputed beforehand to save computational time. Here it is only indexed based on the time_step
a = self.alphas_cumprod.index_select(0, time_step) # (num_graphs, ) # α 累乘
# 2. perturb pos and v
a_pos = a[batch_ligand].unsqueeze(-1) # (num_ligand_atoms, 1) # α累乘 扩展
## sampling a normal distribution to add as noise
# 初始化/采样噪音
pos_noise = torch.zeros_like(ligand_pos)
pos_noise.normal_()
## update the coordinates
# Xt = a.sqrt() * X0 + (1-a).sqrt() * eps 扰动小分子的坐标
ligand_pos_perturbed = a_pos.sqrt() * ligand_pos + (1.0 - a_pos).sqrt() * pos_noise # pos_noise * std
## update the categories 扰动原子类型
# Vt = a * V0 + (1-a) / K
log_ligand_v0 = index_to_log_onehot(ligand_v, self.num_classes) # 转为 onr-hot 类型,然后取对数,即,生成节点类型的对数概率
## move some of the probablity mass to other indexes
# 采样扰动后的节点类型索引,及其对数概率
ligand_v_perturbed, log_ligand_vt = self.q_v_sample(log_ligand_v0, time_step, batch_ligand)
# 3. forward-pass NN, feed perturbed pos and v, output noise
# 将 蛋白坐标和节点类型,扰动后的小分子坐标和节点类型,以及时间步 t 输入到 forward 中
# 输出预测的 binding affinity
preds = self(
protein_pos=protein_pos,
protein_atom_feature=protein_v,
batch_protein=batch_protein,
ligand_pos=ligand_pos_perturbed,
ligand_atom_feature=F.one_hot(ligand_v_perturbed, self.num_classes), # one-hot
batch_ligand=batch_ligand,
time_step=time_step,
fix_x=True
)
if self.problem_type == "regression":
# 如果梯度引导模型是 回归模式
loss_func = nn.MSELoss()
loss = loss_func(preds.view(-1), dock)
elif self.problem_type == "classification":
# 分类模式
loss_func = nn.BCEWithLogitsLoss()
loss = loss_func(preds.view(-1), dock.float())
else:
raise ValueError(f"Unknown problem type: {self.problem_type}")
# 返回损失和预测值
if return_pred:
return loss, preds
else:
return loss
get_loss 函数与第三部分组成了一个完整的梯度引导模型的训练过程,至此,训练一个梯度引导模型相关的部分就全部完成。
5.3 get_gradients_guide
get_gradients_guide 是 DockGuideNet3D 非常重要的一个函数,用于在分子生成的每一个时间 t,小分子的 v_t 和 x_t 组成的分子,与目标属性之间的差距。 在分子生成的 2.3 梯度引导的分子分生成中,调用了 DockGuideNet3D 的 get_gradients_guide 方法来计算梯度,在主模型的 sample_multi_guided_diffusion 方法中被引用计算梯度,返回目标值(Bindding Affinity)对小分子坐标的梯度。
(1)设置评估模式,梯度清零
self.eval() # 模型评估模式 (不启用 dropout 和 batch normalization 等训练时特有的操作)
self.zero_grad() # 梯度清零
(2)启用梯度计算,将坐标和节点设置为需要梯度,调用 forward 函数预测对应的 Binding Affinity 结果,对预测值进行剪裁,并设置预测值是越大越好(self.maximize_property=True)还是越小越好(self.maximize_property=False),计算坐标和节点类型的梯度,返回坐标梯度。(文章中未涉及节点类型的梯度)
# 启动梯度计算
with torch.enable_grad(): ## needed during inference
ligand_pos = ligand_pos.detach().requires_grad_(True) # 设置坐标梯度
if not pos_only:# 设置节点类型梯度
ligand_atom_feature = ligand_atom_feature.detach().requires_grad_(True)
# 调用 forward 函数计算预测的 binding affinity
pred = self(
protein_pos=protein_pos,
protein_atom_feature=protein_atom_feature.float(),
ligand_pos=ligand_pos,
ligand_atom_feature=ligand_atom_feature.float(),
batch_protein=batch_protein,
batch_ligand=batch_ligand,
time_step=time_step,
fix_x=True
)
# 预测结果剪裁
if self.problem_type not in ("classification", ):
if clamp_pred_min is not None or clamp_pred_max is not None:
pred = torch.clamp(pred, min=clamp_pred_min, max=clamp_pred_max)
else:
raise NotImplementedError(f"Not implemented for {self.problem_type} problem type")
# 如果优化目标是值越大越好
if self.maximize_property:
# maximize pred => minimize -pred`
pred = -pred
# pred.mean().backward()
## must reduce the binding energies (Vina Dock) - take the mean of pred across the batch
## "ligand_pos_grad" is in increasing direction
## must subtract it at the end
# 计算梯度
ligand_pos_grad = torch.autograd.grad(pred.sum(), ligand_pos, retain_graph=True)[0] # 坐标梯度
if not pos_only: # 节点类型梯度
ligand_atom_feature_grad = torch.autograd.grad(pred.sum(), ligand_atom_feature, retain_graph=True)[0]
return ligand_pos_grad, ligand_atom_feature_grad
# 返回坐标梯度
return ligand_pos_grad
get_gradients_guide 的 完整代码如下:
def get_gradients_guide(self, protein_pos, protein_atom_feature, ligand_pos, ligand_atom_feature,
batch_protein, batch_ligand, time_step,
pos_only=False,
clamp_pred_min=None, clamp_pred_max=None # 预测值的最大值和最小值剪裁
):
## get the gradients w.r.t ligand position and features using the Binding-affinity EGNN predictor
self.eval() # 模型评估模式 (不启用 dropout 和 batch normalization 等训练时特有的操作)
self.zero_grad() # 梯度清零
# 启动梯度计算
with torch.enable_grad(): ## needed during inference
ligand_pos = ligand_pos.detach().requires_grad_(True) # 设置坐标梯度
if not pos_only:# 设置节点类型梯度
ligand_atom_feature = ligand_atom_feature.detach().requires_grad_(True)
# 调用 forward 函数计算预测的 binding affinity
pred = self(
protein_pos=protein_pos,
protein_atom_feature=protein_atom_feature.float(),
ligand_pos=ligand_pos,
ligand_atom_feature=ligand_atom_feature.float(),
batch_protein=batch_protein,
batch_ligand=batch_ligand,
time_step=time_step,
fix_x=True
)
# 预测结果剪裁
if self.problem_type not in ("classification", ):
if clamp_pred_min is not None or clamp_pred_max is not None:
pred = torch.clamp(pred, min=clamp_pred_min, max=clamp_pred_max)
else:
raise NotImplementedError(f"Not implemented for {self.problem_type} problem type")
# 如果优化目标是值越大越好
if self.maximize_property:
# maximize pred => minimize -pred`
pred = -pred
# pred.mean().backward()
## must reduce the binding energies (Vina Dock) - take the mean of pred across the batch
## "ligand_pos_grad" is in increasing direction
## must subtract it at the end
# 计算梯度
ligand_pos_grad = torch.autograd.grad(pred.sum(), ligand_pos, retain_graph=True)[0] # 坐标梯度
if not pos_only: # 节点类型梯度
ligand_atom_feature_grad = torch.autograd.grad(pred.sum(), ligand_atom_feature, retain_graph=True)[0]
return ligand_pos_grad, ligand_atom_feature_grad
# 返回坐标梯度
return ligand_pos_grad
六、 ScorePosNet3D 中的等变神经网络
主模型 ScorePosNet3D 中可以选择两种等变网络: EGNN 和 uni_o2,作者在文章中和代码中使用的都是 EGNN。 梯度引导模型 DockGuideNet3D 使用的也是 EGNN。
关于 uni_o2 作者提供了代码,但是没有相应的介绍,不知道是否符合 SE3等变。因此这里暂时不介绍。
七、总结
TagMol 是一个梯度引导的分子生成扩散模型,作者提供了完整的项目代码。
通过对这个项目的代码详细解析,可以仔细理解分子生成的扩散模型的原理及其代码架构。
以 TagMol 作为基础模型,可以很快的理解其他基于扩散模型的分子生成模型,可以快速窥探每个模型的特点。