3D分子生成模型 TagMol 代码解析 + 梯度引导扩散模型详解

一、TagMol 背景介绍

TAGMoL 是一个基于分子属性条件引导扩散的 3D 分子生成模型,适合在给定靶标蛋白质的情况下,可以生成一系列满足目标特性(分子属性,binding affinity)的候选分子。

TAGMoL 来源于新德里 Molecule AI, 以及美国马萨诸塞大学曼宁信息与计算机科学学院 的Vineeth Dorna 和 D. Subhalingam 为通讯作者的文章:《TAGMoL : Target Aware Gradient-guided Molecule Generation》。文章链接:https://arxiv.org/abs/2406.01650​​​​​​​ 。该文章于 2024 年 6 月 3 日发表在 arxiv 上。

该模型已经经过了评测,详见 《分子属性梯度引导的3D分子生成扩散模型 TAGMOL - 评测》

分子属性梯度引导的3D分子生成扩散模型 TAGMOL - 评测-CSDN博客

TagMol 模型分为两部分,分子生成模型和梯度引导模型。其中,在原文中,分子生成模型使用的是TargetDiff 模型无需训练,拿原作者提供的 checkpoint 直接使用即可,梯度引导模型为自己训练。在分子生成过程中,使用梯度引导模型,修改 TargetDiff 的生成模型过程,实现特定属性的分子生成。因此,在代码解析部分,我们关心分子生成过程,以及如何训练一个属性引导模型,包含训练梯度引导模型需要的数据预处理过程。(TargetDiff 训练过程则跳过,详见 TargetDiff 部分)。

以下为梯度引导模型及其训练代码解析。

注: 因为原 GitHub 的代码存在 bugs,这里解析的是我们之前修改过的代码,与作者在 GitHub 上提供的稍微不一致。代码的运行环境,更多评测内容,也请看之前的文章《分子属性梯度引导的3D分子生成扩散模型 TAGMOL - 评测》。

二、TagMol 的分子生成代码解析

以多目标分子生成为例,其运行命令为:

python scripts/sample_for_pocket_guided.py \
  configs/noise_guide_multi/sampling_guided_qed_0.33_sa_0.33_ba_0.34.yml \
  --pdb_path ./3wze/pocket_3wze.pdb \
  --result_path outputs_TagMol_3wze_path

在上述命令中,调用 configs/noise_guide_multi/sampling_guided_qed_0.33_sa_0.33_ba_0.34.yml 配置文件,输入的蛋白是 ./3wze/pocket_3wze.pdb, 分子生成结果保存至 outputs_TagMol_3wze_path。

配置文件内容如下,其中,设置了基础模型的 checkpoint,多个引导模型的 checkpoint,及其权重、梯度、类型设置,最后是设置了分子生成采样(采样数量,去噪步数,质量中心等):

# 基础模型
model:
  checkpoint: ./pretrained_models/pretrained_diffusion.pt

# 引导模型
guide_models:
  - name: qed
    checkpoint: ./logs/training_dock_guide_qed_2024_01_06__01_35_21/checkpoints/186000.pt
    weight: 0.33
    guide_kind: Kd
    gradient_scale_cord: 20
    gradient_scale_categ: 0.0 #1e-10
    clamp_pred_max: 1.0
  - name: sa
    checkpoint: ./logs/training_dock_guide_sa_2024_01_20__15_38_49/checkpoints/162000.pt
    weight: 0.33
    guide_kind: Kd
    gradient_scale_cord: 5
    gradient_scale_categ: 0.0 #1e-10
    clamp_pred_max: 1.0
  - name: binding_affinity
    checkpoint: ./logs/training_dock_guide_2023_12_17__06_23_35/checkpoints/184000.pt
    weight: 0.34
    guide_kind: Kd
    gradient_scale_cord: 2.0
    gradient_scale_categ: 0.0 #1e-10

# 采样设置
sample:
  seed: 2021
  num_samples: 100
  num_steps: 1000
  pos_only: False
  center_pos_mode: protein
  sample_num_atoms: prior

因此,我们首先看看 scripts/sample_for_pocket_guided.py 文件的代码。

2.1 __main__ 函数 sample_for_pocket_guided()

首先是 __init__.py 函数。

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('config', type=str)
    parser.add_argument('--pdb_path', type=str) # 口袋/蛋白路径
    parser.add_argument('--device', type=str, default='cuda:0') # GPU
    parser.add_argument('--batch_size', type=int, default=50) # 批次大小
    parser.add_argument('--result_path', type=str, default='./outputs_guided_pdb') # 结果保存路径
    parser.add_argument('--num_samples', type=int) # 采样分子数
    args = parser.parse_args()

    # 分子生成
    sample_for_pocket_guided(config_path=args.config,
                        pdb_path=args.pdb_path, 
                        device=args.device,
                        batch_size=args.batch_size,
                        result_path=args.result_path)

其中,指定了配置文件,口袋/蛋白文件,输出路径,生成分子数。随后,调用 sample_for_pocket_guided 函数进行分子生成。

sample_for_pocket_guided 函数的代码如下:


def sample_for_pocket_guided(config_path, pdb_path, device, batch_size, result_path):
    logger = misc.get_logger('evaluate')

    # Load config 加载配置文件
    config = misc.load_config(config_path)
    logger.info(config)
    misc.seed_all(config.sample.seed)

    # Load checkpoint 加载主模型的 checkpoint
    ckpt = torch.load(config.model.checkpoint, map_location=device)
    logger.info(f"Training Config: {ckpt['config']}")

    # Transforms 主模型数据转换器
    protein_featurizer = trans.FeaturizeProteinAtom()
    ligand_atom_mode = ckpt['config'].data.transform.ligand_atom_mode
    ligand_featurizer = trans.FeaturizeLigandAtom(ligand_atom_mode)
    transform = Compose([
        protein_featurizer,
    ])

    # Load model 加载主模型
    model = ScorePosNet3D(
        ckpt['config'].model,
        protein_atom_feature_dim=protein_featurizer.feature_dim,
        ligand_atom_feature_dim=ligand_featurizer.feature_dim
    ).to(device)
    # 加载模型权重,strict=False,这意味着模型结构和权重不完全匹配时,不会抛出错误,允许跳过一些不匹配的参数。
    model.load_state_dict(ckpt['model'], strict=False if 'train_config' in config.model else True)
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    logger.info(f'Successfully load the model! {config.model.checkpoint}')

    # Load Guide Checkpoint
    # guide_ckpt = torch.load(config.guide_model.checkpoint, map_location=device)
    # guide_ckpt = torch.load(config.model.checkpoint, map_location=device) # wufeil
    # logger.info(f"Guide Training Config: {guide_ckpt['config']}")

    # Guide Transforms 引导模型的数据转换器
    guide_protein_featurizer = GuideFeaturizeProteinAtom()
    guide_ligand_featurizer = GuideFeaturizeLigandAtom(mode=ckpt['config']['data']['transform']['ligand_atom_mode'])


    # Guide model
    # guide_model = get_guide_model(guide_ckpt['config'], guide_protein_featurizer.feature_dim, guide_ligand_featurizer.feature_dim)
    # guide_model.load_state_dict(guide_ckpt['model'])
    # guide_model = guide_model.to(device)
    # guide_model.eval()
    # for param in guide_model.parameters():
    #     param.requires_grad = False

    # Guide models 多个引导模型加载到 guide_models 中
    ### wufeil #######################################################################
    guide_models = []
    for guide_model_config in config.guide_models:
        # Load Guide Checkpoint
        guide_ckpt = torch.load(guide_model_config.checkpoint, map_location=device)
        print(f"\nGuide Name: {guide_model_config.name}")
        logger.info(f"Guide Name: {guide_model_config.name}")
        print(f"\nGuide Training Config: {guide_ckpt['config']}")
        logger.info(f"Guide Training Config: {guide_ckpt['config']}")

        # Guide model
        guide_model = get_guide_model(guide_ckpt['config'], guide_protein_featurizer.feature_dim, guide_ligand_featurizer.feature_dim)
        guide_model.load_state_dict(guide_ckpt['model'])
        guide_model = guide_model.to(device)
        guide_model.eval()
        guide_models.append(guide_model)

    #####################################################################

    # Load pocket 加载口袋 pdb 文件,并转换为模型输入
    data = pdb_to_pocket_data(pdb_path)
    data = transform(data)

    # 分子生成,获得生成分子的节点,各节点坐标(以及生成过程的轨迹)
    all_pred_pos, all_pred_v, all_pred_pos_traj, all_pred_v_traj, all_pred_v0_traj, all_pred_vt_traj, all_pred_pos0_traj, time_list = sample_guided_diffusion_ligand(
        model=model,
        # guide_model=guide_model,
        guide_models=guide_models,# wufeil
        data=data,
        guide_configs=config.guide_models, # wufeil
        num_samples=config.sample.num_samples,
        # kind=config.sample.guide_kind,
        # gradient_scale_cord=config.sample.gradient_scale_cord,
        # gradient_scale_categ=config.sample.gradient_scale_categ,
        batch_size=batch_size, 
        device=device,
        num_steps=config.sample.num_steps,
        pos_only=config.sample.pos_only,
        center_pos_mode=config.sample.center_pos_mode,
        sample_num_atoms=config.sample.sample_num_atoms
    )

    result = {
        'data': data,
        'pred_ligand_pos': all_pred_pos,
        'pred_ligand_v': all_pred_v,
        'pred_ligand_pos_traj': all_pred_pos_traj,
        'pred_ligand_v_traj': all_pred_v_traj,
        'pred_ligand_v0_traj': all_pred_v0_traj,
        'pred_ligand_vt_traj': all_pred_vt_traj,
        'pred_pos0_traj': all_pred_pos0_traj
    }
    logger.info('Sample done!')

    # reconstruction 重构分子,并统计重构成功分子数量
    gen_mols = []
    n_recon_success, n_complete = 0, 0
    for sample_idx, (pred_pos, pred_v) in enumerate(zip(all_pred_pos, all_pred_v)):
        pred_atom_type = trans.get_atomic_number_from_index(pred_v, mode='add_aromatic')
        try:
            pred_aromatic = trans.is_aromatic_from_index(pred_v, mode='add_aromatic')
            mol = reconstruct.reconstruct_from_generated(pred_pos, pred_atom_type, pred_aromatic)
            smiles = Chem.MolToSmiles(mol)
        except reconstruct.MolReconsError:
            gen_mols.append(None)
            continue
        n_recon_success += 1

        if '.' in smiles:
            gen_mols.append(None)
            continue
        n_complete += 1
        gen_mols.append(mol)
    result['mols'] = gen_mols
    logger.info('Reconstruction done!')
    logger.info(f'n recon: {n_recon_success} n complete: {n_complete}')

    # 创建结果路径,并复制保存 配置文件及生成的小分子
    os.makedirs(result_path, exist_ok=True)
    shutil.copyfile(config_path, os.path.join(result_path, 'sample.yml'))
    torch.save(result, os.path.join(result_path, f'sample.pt'))
    mols_save_path = os.path.join(result_path, f'sdf')
    os.makedirs(mols_save_path, exist_ok=True)
    for idx, mol in enumerate(gen_mols):
        if mol is not None:
            sdf_writer = Chem.SDWriter(os.path.join(mols_save_path, f'{idx:03d}.sdf'))
            sdf_writer.write(mol)
            sdf_writer.close()
    logger.info(f'Results are saved in {result_path}')
    misc.close_logger(logger)

sample_for_pocket_guided 函数内容分为以下几步:

(1) 加载配置文件;

(2) 加载主模型(ScorePosNet3D,即 TargetDiff 模型) 数据转换器trans.FeaturizeProteinAtom() 及其model(即,TargetDiff);

(3) 加载引导模型的数据转换器 GuideFeaturizeProteinAtom 及多个属性引导模型 guide_models;

(4) pdb_to_pocket_data 函数加载口袋 pdb 文件得到 data;

(5) 调用 sample_guided_diffusion_ligand 函数传入 model 和 guide_models 以及生成配置,进行分子生成。

2.2 加载口袋 pdb 文件(pdb_to_pocket_data)

pdb_to_pocket_data 函数,传入一个口袋 pdb 文件,返回输入 TagMol 模型前的庶数据格式(字典)。将口袋 pdb 文件使用 PDBProtein 类中的 to_dict_atom 方法,按照原子进行提取,提取成为一个字典 pocket_dict,torchify_dict 转化为pytorch tensor。同时,设置一个空的小分子字典 ligand_dict,与 pocket_dict 通过 ProteinLigandData 类的 from_protein_ligand_dicts 方法组合成 data。代码如下:

def pdb_to_pocket_data(pdb_path):
    pocket_dict = PDBProtein(pdb_path).to_dict_atom()
    data = ProteinLigandData.from_protein_ligand_dicts(
        protein_dict=torchify_dict(pocket_dict),
        ligand_dict={
            'element': torch.empty([0, ], dtype=torch.long),
            'pos': torch.empty([0, 3], dtype=torch.float),
            'atom_feature': torch.empty([0, 8], dtype=torch.float),
            'bond_index': torch.empty([2, 0], dtype=torch.long),
            'bond_type': torch.empty([0, ], dtype=torch.long),
        }
    )

    return data

torchify_dict 内容比较简单,将字典中,值为 numpy 矩阵的对象转化为 pytorch tensor,如下:

def torchify_dict(data):
    # 将 字典中,值为 numpy 矩阵的对象转化为 pytorch tensor
    output = {}
    for k, v in data.items():
        if isinstance(v, np.ndarray):
            output[k] = torch.from_numpy(v)
        else:
            output[k] = v
    return output

2.2.1 解析 pdb 文件 PDBProtein()

PDBProtein 类逐行解析 pdb 文件内容(_parse 方法)。可以返回,原子级别的信息(包括:元素序号、名字、坐标、是否是主链、元素符号、所属的残基编码)、氨基酸级别信息(包括:氨基酸编码、氨基酸质心、CA, C, N, O 原子的坐标)。此外,还可以输入中心坐标和半径输出半径范围内的氨基酸。

详细代码如下:

class PDBProtein(object):
    AA_NAME_SYM = {
        'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F', 'GLY': 'G', 'HIS': 'H',
        'ILE': 'I', 'LYS': 'K', 'LEU': 'L', 'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q',
        'ARG': 'R', 'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y',
    }

    AA_NAME_NUMBER = {
        k: i for i, (k, _) in enumerate(AA_NAME_SYM.items())
    }

    BACKBONE_NAMES = ["CA", "C", "N", "O"]

    def __init__(self, data, mode='auto'):
        super().__init__()
        if (data[-4:].lower() == '.pdb' and mode == 'auto') or mode == 'path':
            with open(data, 'r') as f:
                self.block = f.read()
        else:
            self.block = data

        self.ptable = Chem.GetPeriodicTable()

        # Molecule properties
        self.title = None
        # Atom properties
        self.atoms = []
        self.element = []
        self.atomic_weight = []
        self.pos = []
        self.atom_name = []
        self.is_backbone = []
        self.atom_to_aa_type = []
        # Residue properties
        self.residues = []
        self.amino_acid = []
        self.center_of_mass = []
        self.pos_CA = []
        self.pos_C = []
        self.pos_N = []
        self.pos_O = []

        self._parse()

    def _enum_formatted_atom_lines(self):
        for line in self.block.splitlines():
            if line[0:6].strip() == 'ATOM':
                element_symb = line[76:78].strip().capitalize()
                if len(element_symb) == 0:
                    element_symb = line[13:14]
                yield {
                    'line': line,
                    'type': 'ATOM',
                    'atom_id': int(line[6:11]),
                    'atom_name': line[12:16].strip(),
                    'res_name': line[17:20].strip(),
                    'chain': line[21:22].strip(),
                    'res_id': int(line[22:26]),
                    'res_insert_id': line[26:27].strip(),
                    'x': float(line[30:38]),
                    'y': float(line[38:46]),
                    'z': float(line[46:54]),
                    'occupancy': float(line[54:60]),
                    'segment': line[72:76].strip(),
                    'element_symb': element_symb,
                    'charge': line[78:80].strip(),
                }
            elif line[0:6].strip() == 'HEADER':
                yield {
                    'type': 'HEADER',
                    'value': line[10:].strip()
                }
            elif line[0:6].strip() == 'ENDMDL':
                break  # Some PDBs have more than 1 model.

    def _parse(self):
        # Process atoms
        residues_tmp = {}
        for atom in self._enum_formatted_atom_lines():
            if atom['type'] == 'HEADER':
                self.title = atom['value'].lower()
                continue
            self.atoms.append(atom)
            atomic_number = self.ptable.GetAtomicNumber(atom['element_symb'])
            next_ptr = len(self.element)
            self.element.append(atomic_number)
            self.atomic_weight.append(self.ptable.GetAtomicWeight(atomic_number))
            self.pos.append(np.array([atom['x'], atom['y'], atom['z']], dtype=np.float32))
            self.atom_name.append(atom['atom_name'])
            self.is_backbone.append(atom['atom_name'] in self.BACKBONE_NAMES)
            self.atom_to_aa_type.append(self.AA_NAME_NUMBER[atom['res_name']])

            chain_res_id = '%s_%s_%d_%s' % (atom['chain'], atom['segment'], atom['res_id'], atom['res_insert_id'])
            if chain_res_id not in residues_tmp:
                residues_tmp[chain_res_id] = {
                    'name': atom['res_name'],
                    'atoms': [next_ptr],
                    'chain': atom['chain'],
                    'segment': atom['segment'],
                }
            else:
                assert residues_tmp[chain_res_id]['name'] == atom['res_name']
                assert residues_tmp[chain_res_id]['chain'] == atom['chain']
                residues_tmp[chain_res_id]['atoms'].append(next_ptr)

        # Process residues
        self.residues = [r for _, r in residues_tmp.items()]
        for residue in self.residues:
            sum_pos = np.zeros([3], dtype=np.float32)
            sum_mass = 0.0
            for atom_idx in residue['atoms']:
                sum_pos += self.pos[atom_idx] * self.atomic_weight[atom_idx]
                sum_mass += self.atomic_weight[atom_idx]
                if self.atom_name[atom_idx] in self.BACKBONE_NAMES:
                    residue['pos_%s' % self.atom_name[atom_idx]] = self.pos[atom_idx]
            residue['center_of_mass'] = sum_pos / sum_mass

        # Process backbone atoms of residues
        for residue in self.residues:
            self.amino_acid.append(self.AA_NAME_NUMBER[residue['name']])
            self.center_of_mass.append(residue['center_of_mass'])
            for name in self.BACKBONE_NAMES:
                pos_key = 'pos_%s' % name  # pos_CA, pos_C, pos_N, pos_O
                if pos_key in residue:
                    getattr(self, pos_key).append(residue[pos_key])
                else:
                    getattr(self, pos_key).append(residue['center_of_mass'])

    def to_dict_atom(self):
        return {
            'element': np.array(self.element, dtype=np.long),
            'molecule_name': self.title,
            'pos': np.array(self.pos, dtype=np.float32),
            'is_backbone': np.array(self.is_backbone, dtype=np.bool),
            'atom_name': self.atom_name,
            'atom_to_aa_type': np.array(self.atom_to_aa_type, dtype=np.long)
        }

    def to_dict_residue(self):
        return {
            'amino_acid': np.array(self.amino_acid, dtype=np.long),
            'center_of_mass': np.array(self.center_of_mass, dtype=np.float32),
            'pos_CA': np.array(self.pos_CA, dtype=np.float32),
            'pos_C': np.array(self.pos_C, dtype=np.float32),
            'pos_N': np.array(self.pos_N, dtype=np.float32),
            'pos_O': np.array(self.pos_O, dtype=np.float32),
        }

    def query_residues_radius(self, center, radius, criterion='center_of_mass'):
        center = np.array(center).reshape(3)
        selected = []
        for residue in self.residues:
            distance = np.linalg.norm(residue[criterion] - center, ord=2)
            print(residue[criterion], distance)
            if distance < radius:
                selected.append(residue)
        return selected

    def query_residues_ligand(self, ligand, radius, criterion='center_of_mass'):
        selected = []
        sel_idx = set()
        # The time-complexity is O(mn).
        for center in ligand['pos']:
            for i, residue in enumerate(self.residues):
                distance = np.linalg.norm(residue[criterion] - center, ord=2)
                if distance < radius and i not in sel_idx:
                    selected.append(residue)
                    sel_idx.add(i)
        return selected

    def residues_to_pdb_block(self, residues, name='POCKET'):
        block = "HEADER    %s\n" % name
        block += "COMPND    %s\n" % name
        for residue in residues:
            for atom_idx in residue['atoms']:
                block += self.atoms[atom_idx]['line'] + "\n"
        block += "END\n"
        return block

2.2.2 组合蛋白和小分子的字典 ProteinLigandData()

比较简单,合并蛋白和小分子的字典。

class ProteinLigandData(Data):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @staticmethod
    def from_protein_ligand_dicts(protein_dict=None, ligand_dict=None, **kwargs):
        # 将蛋白和小分子的字典合并成为一个新的字典,
        # 蛋白的 key 前面增加 protein_
        # ligand 的 key 前面增加 ligand_ 
        instance = ProteinLigandData(**kwargs)

        if protein_dict is not None:
            for key, item in protein_dict.items():
                instance['protein_' + key] = item

        if ligand_dict is not None:
            for key, item in ligand_dict.items():
                instance['ligand_' + key] = item

        instance['ligand_nbh_list'] = {i.item(): [j.item() for k, j in enumerate(instance.ligand_bond_index[1])
                                                  if instance.ligand_bond_index[0, k].item() == i]
                                       for i in instance.ligand_bond_index[0]}
        return instance

    def __inc__(self, key, value, *args, **kwargs):
        if key == 'ligand_bond_index':
            return self['ligand_element'].size(0)
        else:
            return super().__inc__(key, value)

2.3 引导模型数据转换 GuideFeaturizeProteinAtom()

引导模型数据转化分为两块,分别对应蛋白和小分子的数据转换,对应类 GuideFeaturizeProteinAtom() 和 GuideFeaturizeLigandAtom() 。来源于:

from utils.transforms_prop import FeaturizeProteinAtom as GuideFeaturizeProteinAtom
from utils.transforms_prop import FollowerFeaturizeLigandAtom as GuideFeaturizeLigandAtom

整个引导模型的数据转换代码,如下(在 sample_for_pocket_guided 函数中):

# Guide Transforms 引导模型的数据转换器
guide_protein_featurizer = GuideFeaturizeProteinAtom()
guide_ligand_featurizer = GuideFeaturizeLigandAtom(mode=ckpt['config']['data']['transform']['ligand_atom_mode'])

2.3.1 引导模型的蛋白数据特征化

由 GuideFeaturizeProteinAtom() 类实现,主要功能为将 protein_data 字典组合成为节点特征,以及返回节点特征维度。

蛋白中原子的特征包括:元素(H, C, N, O, S, Se)、氨基酸类型、是否是主链。

代码如下:

 class FeaturizeProteinAtom(object):

    def __init__(self):
        super().__init__()
        self.atomic_numbers = torch.LongTensor([1, 6, 7, 8, 16, 34])    # H, C, N, O, S, Se
        self.max_num_aa = 20

    @property
    def feature_dim(self):
        # 节点特征维度
        return self.atomic_numbers.size(0) + self.max_num_aa + 1

    def __call__(self, data: ProteinLigandData):
        # 将 data(字典)组合成节点特征矩阵 
        element = data.protein_element.view(-1, 1) == self.atomic_numbers.view(1, -1)   # (N_atoms, N_elements)
        amino_acid = F.one_hot(data.protein_atom_to_aa_type, num_classes=self.max_num_aa)
        is_backbone = data.protein_is_backbone.view(-1, 1).long()
        x = torch.cat([element, amino_acid, is_backbone], dim=-1)
        data.protein_atom_feature = x
        return data

2.3.2 引导模型的小分子数据特征化

将小分子数据 ligand_data 特征化,由 dict 转化为节点特征矩阵,以及提取小分子特征矩阵的维度。

小分子可以兼容的元素为:H, C, N, O, F, P, S, Cl ,小分子的节点特征为:元素、原子芳香性。(注意,小分子和蛋白的节点特征不一致)

代码如下:

class FollowerFeaturizeLigandAtom(object):
    '''
    将小分子的 ligand_data 字典转化为特征矩阵
    '''
    
    def __init__(self, mode='basic'):
        super().__init__()
        self.atomic_numbers = torch.LongTensor([1, 6, 7, 8, 9, 15, 16, 17])  # H, C, N, O, F, P, S, Cl
        assert mode in ['basic', 'add_aromatic', 'full']
        self.mode = mode
        # self.n_degree = torch.LongTensor([0, 1, 2, 3, 4, 5])  # 0 - 5
        # self.n_num_hs = 6  # 0 - 5

    @property
    def feature_dim(self):
        if self.mode == 'basic':
            return len(MAP_ATOM_TYPE_ONLY_TO_INDEX)
        elif self.mode == 'add_aromatic':
            return len(MAP_ATOM_TYPE_AROMATIC_TO_INDEX)
        else:
            return len(MAP_ATOM_TYPE_FULL_TO_INDEX)

    def __call__(self, data: ProteinLigandData):
        element_list = data['ligand_element'] # 元素列表
        aromatic_list = data['ligand_atom_feature'][:,1] # 原子芳香性列表
        x = [get_index(e, None, a, self.mode) for e, a in zip(element_list, aromatic_list)]
        x = torch.tensor(x)
        x = F.one_hot(x, num_classes=self.feature_dim) # 节点特正 onr-hot 
        data.ligand_atom_feature_full = x # 小分子节点特征矩阵
        return data

2.4 加载梯度引导模型 get_guide_model()

在 sample_for_pocket_guided 函数中,加载梯度引导模型的相关代码是:(注:此部分为我们修改过后的代码)

guide_models = []
    for guide_model_config in config.guide_models:
        # Load Guide Checkpoint
        guide_ckpt = torch.load(guide_model_config.checkpoint, map_location=device)
        print(f"\nGuide Name: {guide_model_config.name}")
        logger.info(f"Guide Name: {guide_model_config.name}")
        print(f"\nGuide Training Config: {guide_ckpt['config']}")
        logger.info(f"Guide Training Config: {guide_ckpt['config']}")

        # Guide model
        guide_model = get_guide_model(guide_ckpt['config'], guide_protein_featurizer.feature_dim, guide_ligand_featurizer.feature_dim)
        guide_model.load_state_dict(guide_ckpt['model'])
        guide_model = guide_model.to(device)
        guide_model.eval()
        guide_models.append(guide_model)

通过 for 循环,get_guide_model 函数逐个加载梯度引导模型到梯度引导模型列表中 guide_models。

get_guide_model() 实际上对应 get_model() 函数, 来源于:

from scripts.property_prediction.inference import get_model as get_guide_model

get_model() 函数传入模型的类型(默认是 egnn ),蛋白节点特征维度、小分子特征维度,返回梯度引导模型 DockGuideNet3D。

代码如下:

def get_model(config, protein_atom_feat_dim, ligand_atom_feat_dim):
    '''
    实例化 梯度引导模型 DockGuideNet3D
    '''
    if config.model.model_type == 'egnn':
        ## the noisy guide model trained on diffused inputs
        model = DockGuideNet3D(
            config.model,
            protein_atom_feature_dim=protein_atom_feat_dim,
            ligand_atom_feature_dim=ligand_atom_feat_dim
        )
        return model
    else:
        raise NotImplementedError

2.5 引导分子采样 sample_guided_diffusion_ligand()

在完成口袋、主模型、梯度引导模型的加载,然后使用 sample_guided_diffusion_ligand 函数进行批次化的引导分子采样(包含了数据的批次化)。

sample_guided_diffusion_ligand 函数源自于 ./scipts/sample_multi_guided_diffusion.py 文件,其代码如下:

def sample_guided_diffusion_ligand(model, guide_models, guide_configs, data, num_samples, batch_size=1, device='cuda:0',
                            num_steps=None, pos_only=False, center_pos_mode='protein',
                            sample_num_atoms='prior'): # batch_size=16 改成 4
    '''
    逐个批次进行分子生成,每个批次内:
    1. 采样生成分子的原子数;
    2. 初始化小分子的位置和节点类型
    3. 使用主模型的 model.sample_multi_guided_diffusion,进行采样
    4. 生成分子坐标、节点类型、轨迹的解批次
    '''
    # 将主函数和梯度引导函数设置为评估模式
    model.eval()
    for guide_model in guide_models:
        guide_model.eval()

    # 多个分子的坐标和节点类型
    all_pred_pos, all_pred_v = [], []

    # 多个分子的生成轨迹(即:每个时间 t 下的坐标和节点类型)
    all_pred_pos_traj, all_pred_pos0_traj, all_pred_v_traj = [], [], []
    all_pred_v0_traj, all_pred_vt_traj = [], []

    time_list = []

    # 批次数
    num_batch = int(np.ceil(num_samples / batch_size))
    current_i = 0
    for i in tqdm(range(num_batch)): # 逐批次生成
        # 批次中生成的分子数
        n_data = batch_size if i < num_batch - 1 else num_samples - batch_size * (num_batch - 1)
        # 复制data, 组合成批次数据, 
        # FOLLOW_BATCH 特征在批处理中将保持单独的批次信息,确保合并后仍能区分原来属于哪个图结构的数据。
        # 这里跟踪的是('protein_element', 'ligand_element', 'ligand_bond_type')
        # 会产生一个 protein_element_batch 的张量形状为 (批次中节点数,),标记每一个节点属于哪一个图(分子)。
        batch = Batch.from_data_list([data.clone() for _ in range(n_data)], follow_batch=FOLLOW_BATCH).to(device)

        t1 = time.time()
        with torch.no_grad():
            batch_protein = batch.protein_element_batch

            # 各生成分子的原子数 batch_ligand
            if sample_num_atoms == 'prior':
                # 通过口袋形状,根据预定义好的口袋跨度-分子原子数分布中,采样各生成分子的原子数
                pocket_size = atom_num.get_space_size(batch.protein_pos.detach().cpu().numpy()) # 口袋跨度
                ligand_num_atoms = [atom_num.sample_atom_num(pocket_size).astype(int) for _ in range(n_data)] # 采样原子数
                batch_ligand = torch.repeat_interleave(torch.arange(n_data), torch.tensor(ligand_num_atoms)).to(device) # batch 口袋原子数
            elif sample_num_atoms == 'range':
                # 按照范围设定
                ligand_num_atoms = list(range(current_i + 1, current_i + n_data + 1))
                batch_ligand = torch.repeat_interleave(torch.arange(n_data), torch.tensor(ligand_num_atoms)).to(device)
            elif sample_num_atoms == 'ref':
                # 按照当前分子的原子数
                batch_ligand = batch.ligand_element_batch
                ligand_num_atoms = scatter_sum(torch.ones_like(batch_ligand), batch_ligand, dim=0).tolist()
            else:
                raise ValueError

            # init ligand pos,初始化小分子坐标
            # 使用 scatter_mean计算 batch.protein_pos 的加权均值
            # 按照 batch_protein 权重计算,舍弃小分子, 以及沿着批次维度计算),
            # batch_protein 和 batch_ligand 都是 mask,标记 batch 中哪里是蛋白,哪里是小分子 
            center_pos = scatter_mean(batch.protein_pos, batch_protein, dim=0)
            batch_center_pos = center_pos[batch_ligand]
            init_ligand_pos = batch_center_pos + torch.randn_like(batch_center_pos) # 蛋白质心加上随机值

            # init ligand v 初始化小分子节点类型,均为0初始化,随机概率初始化
            if pos_only:
                # 节点类型不初始化,均为0
                init_ligand_v = batch.ligand_atom_feature_full
            else:
                # 节点类型随机概率初始化
                uniform_logits = torch.zeros(len(batch_ligand), model.num_classes).to(device)
                init_ligand_v = log_sample_categorical(uniform_logits)

            # 主模型+引导模型进行分子生成
            r = model.sample_multi_guided_diffusion(
                guide_models=guide_models, # 引导模型
                guide_configs=guide_configs, # 引导模型配置
                n_data=n_data, # 批次中分子数
                device=device, 
                protein_pos=batch.protein_pos, # 蛋白坐标
                protein_v=batch.protein_atom_feature.float(), # 蛋白节点类型
                batch_protein=batch_protein, # 蛋白的 mask

                init_ligand_pos=init_ligand_pos, # 初始化小分子坐标
                init_ligand_v=init_ligand_v, # 初始化小分子 节点类型
                batch_ligand=batch_ligand, # 小分子 mask
                num_steps=num_steps, # 去噪/扩散步数
                pos_only=pos_only, 
                center_pos_mode=center_pos_mode # 中心模式, 蛋白
            )

            ligand_pos, ligand_v, ligand_pos_traj, ligand_v_traj = r['pos'], r['v'], r['pos_traj'], r['v_traj']
            ligand_v0_traj, ligand_vt_traj, ligand_pos0_traj = r['v0_traj'], r['vt_traj'], r['pos0_traj']
            
            # unbatch pos,坐标 ligand_pos 解批次化,并转为 numpy, 矩阵形状为 num_samples * [num_atoms_i, 3]
            ligand_cum_atoms = np.cumsum([0] + ligand_num_atoms)
            ligand_pos_array = ligand_pos.cpu().numpy().astype(np.float64)
            all_pred_pos += [ligand_pos_array[ligand_cum_atoms[k]:ligand_cum_atoms[k + 1]] for k in
                                range(n_data)]  # num_samples * [num_atoms_i, 3]

            # 坐标去噪轨迹 ligand_pos_traj 解批次
            all_step_pos = [[] for _ in range(n_data)]
            for p in ligand_pos_traj:  # step_i
                p_array = p.cpu().numpy().astype(np.float64)
                for k in range(n_data):
                    all_step_pos[k].append(p_array[ligand_cum_atoms[k]:ligand_cum_atoms[k + 1]])
            all_step_pos = [np.stack(step_pos) for step_pos in
                            all_step_pos]  # num_samples * [num_steps, num_atoms_i, 3]
            all_pred_pos_traj += [p for p in all_step_pos]
            
            # all_step_pos0 解批次
            all_step_pos0 = [[] for _ in range(n_data)]
            for p in ligand_pos0_traj:  # step_i
                p_array = p.cpu().numpy().astype(np.float64)
                for k in range(n_data):
                    all_step_pos0[k].append(p_array[ligand_cum_atoms[k]:ligand_cum_atoms[k + 1]])
            all_step_pos0 = [np.stack(step_pos) for step_pos in
                            all_step_pos0]  # num_samples * [num_steps, num_atoms_i, 3]
            all_pred_pos0_traj += [p for p in all_step_pos0]

            # unbatch v 节点类型 ligand_v 解批次
            ligand_v_array = ligand_v.cpu().numpy()
            all_pred_v += [ligand_v_array[ligand_cum_atoms[k]:ligand_cum_atoms[k + 1]] for k in range(n_data)]

            # 节点类型去噪轨迹 ligand_v_traj 解批次
            all_step_v = unbatch_v_traj(ligand_v_traj, n_data, ligand_cum_atoms)
            all_pred_v_traj += [v for v in all_step_v]

            if not pos_only:
                all_step_v0 = unbatch_v_traj(ligand_v0_traj, n_data, ligand_cum_atoms)
                all_pred_v0_traj += [v for v in all_step_v0]
                all_step_vt = unbatch_v_traj(ligand_vt_traj, n_data, ligand_cum_atoms)
                all_pred_vt_traj += [v for v in all_step_vt]
        t2 = time.time()
        time_list.append(t2 - t1)
        current_i += n_data
    return all_pred_pos, all_pred_v, all_pred_pos_traj, all_pred_v_traj, all_pred_v0_traj, all_pred_vt_traj, all_pred_pos0_traj, time_list

sample_guided_diffusion_ligand 函数的主要内容为:逐个批次进行分子生成,每个批次内,

    1. 采样生成分子的原子数;sample_num_atoms 参数决定生成分子的原子数来源, prior 根据预设好的口袋跨度-原子数关系采样;range 按照预设值的范围进行采样;ref 根据参考分子的原数量决定。

    2. 初始化小分子的位置和节点类型:坐标初始化为蛋白中心+随机数初始化;节点类型使用随机概率初始化或者不初始化(均为0)。

    3. 使用主模型的 model.sample_multi_guided_diffusion 函数,进行采样

    4. 生成分子坐标、节点类型、轨迹的解批次

其中,关键的 model.sample_multi_guided_diffusion 函数是,主模型+梯度引导模型,根据初始化的小分子坐标、节点类型以及蛋白的坐标和节点类型,进行分子生成。这个 sample_multi_guided_diffusion 函数是作者在 TargetDiff 模型 (ScorePosNet3D)上新增的函数,同样需新增的函数还有 sample_guided_diffusion 。除了这两个函数,原来 TargetDiff 的代码保持不变。

2.6 TagetDiff +多引导模型采样(ScorePosNet3D 中sample_multi_guided_diffusion)

主要由 ScorePosNet3D 模型的 sample_multi_guided_diffusion 方法实现。(注:由于 ScorePosNet3D 是一个完整的 3D 分子生成扩散模型,包含了训练损失,生成,以及其他添加噪音等一系列的函数,比较复杂。在这里仅仅对多种梯度模型引导的分子生成函数 sample_multi_guided_diffusion 及其支持函数进行解析。)

在 sample_multi_guided_diffusion 函数中,首先是基本准备,包括:检查梯度引导模型数量与配置数量是否相同;坐标中心归 0,调整小分子和蛋白的坐标;初始化 V_t, V_0, X_t, X_0, V_t-1, X_t-1等去噪过程记录列表;初始化时间 t 列表。代码为:

        # 检查 梯度引导模型数量与配置数量是否相同
        assert len(guide_models) == len(guide_configs), f"guide_models and guide_configs must have the same length"
        # 去噪步数
        if num_steps is None:
            num_steps = self.num_timesteps
        num_graphs = batch_protein.max().item() + 1

        # 整体中心归 0,调整小分子和蛋白的坐标
        protein_pos, init_ligand_pos, offset = center_pos(
            protein_pos, init_ligand_pos, batch_protein, batch_ligand, mode=center_pos_mode)

        pos_traj, v_traj = [], []
        v0_pred_traj, vt_pred_traj, pos0_traj = [], [], []
        ligand_pos, ligand_v = init_ligand_pos, init_ligand_v

        # time sequence 时间序列,从大到小
        time_seq = list(reversed(range(self.num_timesteps - num_steps, self.num_timesteps)))

然后进行逐步去噪过程。主要涉及:

(1)某个时刻 t 下,输入 V_t, X_t 到神经网络预测 V_0,X_0, 然后计算先验 V_t-1, X_t-1,提取真实的分子坐标即 X_0。 涉及代码:

        # 逐步去噪
        for i in tqdm(time_seq, desc='sampling', total=len(time_seq)):
            # 某 t 时刻
            t = torch.full(size=(num_graphs,), fill_value=i, dtype=torch.long, device=protein_pos.device)
            # 输入 V_t,神经网络预测 V_0
            with torch.no_grad():
                preds = self(
                    protein_pos=protein_pos,
                    protein_v=protein_v,
                    batch_protein=batch_protein,

                    init_ligand_pos=ligand_pos,
                    init_ligand_v=ligand_v,
                    batch_ligand=batch_ligand,
                    time_step=t
                )
            # Compute posterior mean and variance
            if self.model_mean_type == 'noise':
                # 如果神经网路输出的是噪音,提取真实分子坐标
                pred_pos_noise = preds['pred_ligand_pos'] - ligand_pos
                pos0_from_e = self._predict_x0_from_eps(xt=ligand_pos, eps=pred_pos_noise, t=t, batch=batch_ligand)
                v0_from_e = preds['pred_ligand_v']
                raise NotImplementedError
            elif self.model_mean_type == 'C0':
                # 如果神经网络输出的是真实分子坐标
                pos0_from_e = preds['pred_ligand_pos']
                v0_from_e = preds['pred_ligand_v']

            else:
                raise ValueError
            
            # 输入 V_t, X_t 和神经网络预测的 V_0, X_0 计算先验 V_t-1, X_t-1
            # 坐标
            pos_model_mean = self.q_pos_posterior(x0=pos0_from_e, xt=ligand_pos, t=t, batch=batch_ligand)
            # 坐标的对数方差,即不确定性
            pos_log_variance = extract(self.posterior_logvar, t, batch_ligand)

其中,关于坐标归 0 的函数 center_pos,逐个计算批次中每个蛋白-小分子体系的坐标均值,代码如下:

def center_pos(protein_pos, ligand_pos, batch_protein, batch_ligand, mode='protein'):
    if mode == 'none':
        # 不进行任何中心化操作,即坐标不归0
        offset = 0.
        pass
    elif mode == 'protein':
        # 计算 protein_pos 在每个批次的均值 offset。
        # scatter_mean 函数根据 batch_protein 张量逐个计算均值。
        offset = scatter_mean(protein_pos, batch_protein, dim=0) # 计算每个批次的蛋白中心位置
        protein_pos = protein_pos - offset[batch_protein] 
        ligand_pos = ligand_pos - offset[batch_ligand]
    else:
        raise NotImplementedError
    return protein_pos, ligand_pos, offset

(2) 输入 X_t-1, V_t-1 由梯度引导模型计算梯度。

注意,每一个引导模型都有自己的权重,以及梯度的缩放因子,即每个梯度引导模型产生的梯度都会受到 权重*梯度缩放因子的影响。对于多梯度引导模型来说,梯度权重,梯度缩放因子都是超参数,根据不同的模型需要不同设置。(显然,这是一个超参数问题,实际应用需要超参数检索???)

各个梯度引导模型计算梯度的函数是 get_gradients_guide()。将在梯度引导模型部分详细解析。

涉及代码:

            ## Classifier Guidance for the Diffusion 梯度引导
            ligand_pos_grad, ligand_v_grad = None, None
            # 逐个计算不同引导模型的梯度
            for guide_model, guide_config in zip(guide_models, guide_configs):
                guide_weight = guide_config.weight
                gradient_scale_cord = guide_config.gradient_scale_cord # 坐标梯度引导强度
                gradient_scale_categ = guide_config.gradient_scale_categ # 节点类型梯度引导强度
                clamp_pred_min = guide_config.get("clamp_pred_min", None) # 最小梯度值,默认 None
                clamp_pred_max = guide_config.get("clamp_pred_max", None) # 最大梯度值,默认为 None

                kind = guide_config.guide_kind # 梯度引导类型
                kind = torch.tensor([KMAP[kind]]*n_data).to(device) # 引导类型编号 {'Ki': 1, 'Kd': 2, 'IC50': 3}

                # 调用梯度引导模型计算梯度
                curr_ligand_pos_grad, curr_ligand_v_grad = guide_model.get_gradients_guide(
                    protein_pos=protein_pos,
                    protein_atom_feature=protein_v,
                    ligand_pos=ligand_pos,
                    ligand_atom_feature=F.one_hot(ligand_v,self.num_classes).float(), # 小分子的节点类型从概率转换为 noe-hot 
                    batch_protein=batch_protein,
                    batch_ligand=batch_ligand,
                    # output_kind=kind,
                    time_step=t,
                    pos_only=False,
                    clamp_pred_min=clamp_pred_min,
                    clamp_pred_max=clamp_pred_max,
                )

                # NOTE: extra terms to be applied later
                # 坐标类型的梯度
                if ligand_pos_grad is None:
                    # 第一个梯度引导模型,未有梯度积累 ligand_pos_grad 为 None
                    ligand_pos_grad = guide_weight * gradient_scale_cord * curr_ligand_pos_grad
                else:
                    ligand_pos_grad += (guide_weight * gradient_scale_cord * curr_ligand_pos_grad)
                
                # 节点类型的梯度
                if gradient_scale_categ != 0:
                    if ligand_v_grad is None:
                        ligand_v_grad = guide_weight * gradient_scale_categ * curr_ligand_v_grad
                    else:
                        ligand_v_grad += (guide_weight * gradient_scale_categ * curr_ligand_v_grad)

(3) 使用梯度更新坐标 X_t-1和节点类型 V_t-1。

需要注意的是,每一个坐标的梯度(移动程度)是根据坐标的确定性概率 (pos_log_variance) 进行了调整。节点特征部分,文章和代码中默认是不能使用梯度更新的,但是保留相应的代码。节点类型的梯度更新的是概率,而不是神经网络输入的对数概率。(神经网络输入和输出的节点类型都是对数概率)

涉及代码:

            ## update the coordinates based on the computed gradient after scaling  
            # 更新节点类型,按照不确定进行更新          
            ligand_pos_grad_update = ligand_pos_grad * ((0.5 * pos_log_variance).exp())
            # 更新坐标位置
            pos_model_mean = pos_model_mean - ligand_pos_grad_update

            # (注:按照超参数设置,节点类型是不更新的,因此文章设置了报错)
            # 节点类型的梯度只做了展示,但是没有更新到节点中
            assert ligand_v_grad is None, "Non-zero value for `gradient_scale_categ` is experimental and not part of the paper."
            
            ## end of classifier guidance
                
            # no noise when t == 0
            nonzero_mask = (1 - (t == 0).float())[batch_ligand].unsqueeze(-1) # ligand 的节点 mask, 如果 t=0则全部为0,即不更新
            # 随机添加噪音到坐标中
            ligand_pos_next = pos_model_mean + nonzero_mask * (0.5 * pos_log_variance).exp() * torch.randn_like(
                ligand_pos)
            # 将 X_t-1 设置为新的 X_t 准备下一个 t-1 时刻的去噪
            ligand_pos = ligand_pos_next

            if not pos_only:
                # 如果扩散模型不是只对坐标,节点类型也参与扩散,默认是 pos_only 为 False
                ## v0^ predicted from current time step, 神经网络预测的 V_0
                log_ligand_v_recon = F.log_softmax(v0_from_e, dim=-1)
                ## vt,V_t
                log_ligand_v = index_to_log_onehot(ligand_v, self.num_classes)                
                ## posterior probability of vt-1 given v0^ and vt
                # 计算 V_t-1, 是一个 log 概率
                log_model_prob = self.q_v_posterior(log_ligand_v_recon, log_ligand_v, t, batch_ligand)
                
                ## Classifer update
                ## update the categories based on gradient guidance
                # 如果节点类型由梯度,参数默认是没有梯度
                if ligand_v_grad is not None:
                    ## heuristic-driven. no significant mathematical justification 
                    # prob [0.8, 0.2]
                    # guidance [-0.9, -0.7]
                    # [-0.1, -0.5]
                    # [0.4, 0.0]
                    # [1, 0]
                    
                    # applying gradient scale & weights have been done after the get_gradients_guide() step itself
                    # 梯度是添加到概率中的,不是对数概率
                    updated_model_prob = torch.exp(log_model_prob) - ligand_v_grad
                    # 每一行的最小值归零,使得 updated_model_prob 中的所有值都非负,并保持每一行的相对差异不变
                    updated_model_prob = updated_model_prob - torch.min(updated_model_prob,axis=-1).values.unsqueeze(1)
                    # 每一行概率归一化
                    updated_model_prob = (updated_model_prob.T / updated_model_prob.sum(axis=1)).T
                    # 更新后的概率转化为 对数概率
                    log_model_prob = torch.log(updated_model_prob)
                ## end of classifier guidance
                
                ## sample vt-1 from the probabilities 根据节点对数概率采样节点概率(节点特征)
                ligand_v_next = log_sample_categorical(log_model_prob)
                # import pdb;pdb.set_trace()
                v0_pred_traj.append(log_ligand_v_recon.clone().cpu())
                vt_pred_traj.append(log_model_prob.clone().cpu())
                # 将 V_t-1 设置为新的 V_t 准备下一个 t-1 时刻的去噪
                ligand_v = ligand_v_next

(4) 记录过程

当然,要恢复小分子的中心

            # 记录过程
            ori_ligand_pos0 = pos0_from_e + offset[batch_ligand] # 还原小分子位置,把坐标中心加回来
            ori_ligand_pos = ligand_pos + offset[batch_ligand]
            pos0_traj.append(ori_ligand_pos0.clone().cpu())
            pos_traj.append(ori_ligand_pos.clone().cpu())
            v_traj.append(ligand_v.clone().cpu())

(5) 函数返回

经过全部 t 时刻的去噪,既可返回生成的小分子的坐标和节点类型。

ligand_pos = ligand_pos + offset[batch_ligand] # 还原小分子位置,把坐标中心加回来
        return {
            'pos': ligand_pos, # 生成的坐标
            'v': ligand_v, # 生成的节点类型
            'pos_traj': pos_traj, # 每个 t 时刻,X_t-1, 包含了梯度调整后的结果
            'pos0_traj': pos0_traj, # 每个 t 时刻,神经网络预测的 X_0
            'v_traj': v_traj, # 每个 t 时刻,V_t-1, 包含了梯度调整后的结果
            'v0_traj': v0_pred_traj, # 每个 t 时刻,神经网络预测的 V_0
            'vt_traj': vt_pred_traj # 每个 t 时刻,V_t
        }

(6) ScorePosNet3D 模型的 sample_multi_guided_diffusion 方法的完整代码:

    def sample_multi_guided_diffusion(self, guide_models, guide_configs, n_data, device, protein_pos, protein_v, batch_protein,
                         init_ligand_pos, init_ligand_v, batch_ligand,
                         num_steps=None, center_pos_mode=None, pos_only=False):
        # 检查 梯度引导模型数量与配置数量是否相同
        assert len(guide_models) == len(guide_configs), f"guide_models and guide_configs must have the same length"
        # 去噪步数
        if num_steps is None:
            num_steps = self.num_timesteps
        num_graphs = batch_protein.max().item() + 1

        # 中心归 0,调整小分子和蛋白的坐标
        protein_pos, init_ligand_pos, offset = center_pos(
            protein_pos, init_ligand_pos, batch_protein, batch_ligand, mode=center_pos_mode)

        pos_traj, v_traj = [], []
        v0_pred_traj, vt_pred_traj, pos0_traj = [], [], []
        ligand_pos, ligand_v = init_ligand_pos, init_ligand_v
        # time sequence 时间序列,从大到小
        time_seq = list(reversed(range(self.num_timesteps - num_steps, self.num_timesteps)))
        # 逐步去噪
        for i in tqdm(time_seq, desc='sampling', total=len(time_seq)):
            # 某 t 时刻
            t = torch.full(size=(num_graphs,), fill_value=i, dtype=torch.long, device=protein_pos.device)
            # 输入 V_t,神经网络预测 V_0
            with torch.no_grad():
                preds = self(
                    protein_pos=protein_pos,
                    protein_v=protein_v,
                    batch_protein=batch_protein,

                    init_ligand_pos=ligand_pos,
                    init_ligand_v=ligand_v,
                    batch_ligand=batch_ligand,
                    time_step=t
                )
            # Compute posterior mean and variance
            if self.model_mean_type == 'noise':
                # 如果神经网路输出的是噪音,提取真实分子坐标
                pred_pos_noise = preds['pred_ligand_pos'] - ligand_pos
                pos0_from_e = self._predict_x0_from_eps(xt=ligand_pos, eps=pred_pos_noise, t=t, batch=batch_ligand)
                v0_from_e = preds['pred_ligand_v']
                raise NotImplementedError
            elif self.model_mean_type == 'C0':
                # 如果神经网络输出的是真实分子坐标
                pos0_from_e = preds['pred_ligand_pos']
                v0_from_e = preds['pred_ligand_v']

            else:
                raise ValueError
            
            # 输入 V_t, X_t 和神经网络预测的 V_0, X_0 计算先验 V_t-1, X_t-1
            # 坐标
            pos_model_mean = self.q_pos_posterior(x0=pos0_from_e, xt=ligand_pos, t=t, batch=batch_ligand)
            # 坐标的对数方差,即不确定性
            pos_log_variance = extract(self.posterior_logvar, t, batch_ligand)
            
            ## Classifier Guidance for the Diffusion 梯度引导
            ligand_pos_grad, ligand_v_grad = None, None
            # 逐个计算不同引导模型的梯度
            for guide_model, guide_config in zip(guide_models, guide_configs):
                guide_weight = guide_config.weight
                gradient_scale_cord = guide_config.gradient_scale_cord # 坐标梯度引导强度
                gradient_scale_categ = guide_config.gradient_scale_categ # 节点类型梯度引导强度
                clamp_pred_min = guide_config.get("clamp_pred_min", None) # 最小梯度值,默认 None
                clamp_pred_max = guide_config.get("clamp_pred_max", None) # 最大梯度值,默认为 None

                kind = guide_config.guide_kind # 梯度引导类型
                kind = torch.tensor([KMAP[kind]]*n_data).to(device) # 引导类型编号 {'Ki': 1, 'Kd': 2, 'IC50': 3}

                # ligand_pos_grad = guide_model.get_gradients_guide(
                #     protein_pos=protein_pos,
                #     protein_atom_feature=protein_v,
                #     ligand_pos=ligand_pos,
                #     ligand_atom_feature=F.one_hot(ligand_v,self.num_classes).float(),
                #     batch_protein=batch_protein,
                #     batch_ligand=batch_ligand,
                #     output_kind=kind,
                #     pos_only=True
                # )

                # 调用梯度引导模型计算梯度
                curr_ligand_pos_grad, curr_ligand_v_grad = guide_model.get_gradients_guide(
                    protein_pos=protein_pos,
                    protein_atom_feature=protein_v,
                    ligand_pos=ligand_pos,
                    ligand_atom_feature=F.one_hot(ligand_v,self.num_classes).float(), # 小分子的节点类型从概率转换为 noe-hot 
                    batch_protein=batch_protein,
                    batch_ligand=batch_ligand,
                    # output_kind=kind,
                    time_step=t,
                    pos_only=False,
                    clamp_pred_min=clamp_pred_min,
                    clamp_pred_max=clamp_pred_max,
                )

                # NOTE: extra terms to be applied later
                # 坐标类型的梯度
                if ligand_pos_grad is None:
                    # 第一个梯度引导模型,未有梯度积累 ligand_pos_grad 为 None
                    ligand_pos_grad = guide_weight * gradient_scale_cord * curr_ligand_pos_grad
                else:
                    ligand_pos_grad += (guide_weight * gradient_scale_cord * curr_ligand_pos_grad)
                
                # 节点类型的梯度
                if gradient_scale_categ != 0:
                    if ligand_v_grad is None:
                        ligand_v_grad = guide_weight * gradient_scale_categ * curr_ligand_v_grad
                    else:
                        ligand_v_grad += (guide_weight * gradient_scale_categ * curr_ligand_v_grad)

            ## update the coordinates based on the computed gradient after scaling  
            # 更新节点类型,按照不确定进行更新          
            ligand_pos_grad_update = ligand_pos_grad * ((0.5 * pos_log_variance).exp())
            # 更新坐标位置
            pos_model_mean = pos_model_mean - ligand_pos_grad_update

            # (注:按照超参数设置,节点类型是不更新的,因此文章设置了报错)
            # 节点类型的梯度只做了展示,但是没有更新到节点中
            assert ligand_v_grad is None, "Non-zero value for `gradient_scale_categ` is experimental and not part of the paper."
            
            ## end of classifier guidance
                
            # no noise when t == 0
            nonzero_mask = (1 - (t == 0).float())[batch_ligand].unsqueeze(-1) # ligand 的节点 mask, 如果 t=0则全部为0,即不更新
            # 随机添加噪音到坐标中
            ligand_pos_next = pos_model_mean + nonzero_mask * (0.5 * pos_log_variance).exp() * torch.randn_like(
                ligand_pos)
            # 将 X_t-1 设置为新的 X_t 准备下一个 t-1 时刻的去噪
            ligand_pos = ligand_pos_next

            if not pos_only:
                # 如果扩散模型不是只对坐标,节点类型也参与扩散,默认是 pos_only 为 False
                ## v0^ predicted from current time step, 神经网络预测的 V_0
                log_ligand_v_recon = F.log_softmax(v0_from_e, dim=-1)
                ## vt,V_t
                log_ligand_v = index_to_log_onehot(ligand_v, self.num_classes)                
                ## posterior probability of vt-1 given v0^ and vt
                # 计算 V_t-1, 是一个 log 概率
                log_model_prob = self.q_v_posterior(log_ligand_v_recon, log_ligand_v, t, batch_ligand)
                
                ## Classifer update
                ## update the categories based on gradient guidance
                # 如果节点类型由梯度,参数默认是没有梯度
                if ligand_v_grad is not None:
                    ## heuristic-driven. no significant mathematical justification 
                    # prob [0.8, 0.2]
                    # guidance [-0.9, -0.7]
                    # [-0.1, -0.5]
                    # [0.4, 0.0]
                    # [1, 0]
                    
                    # applying gradient scale & weights have been done after the get_gradients_guide() step itself
                    # 梯度是添加到概率中的,不是对数概率
                    updated_model_prob = torch.exp(log_model_prob) - ligand_v_grad
                    # 每一行的最小值归零,使得 updated_model_prob 中的所有值都非负,并保持每一行的相对差异不变
                    updated_model_prob = updated_model_prob - torch.min(updated_model_prob,axis=-1).values.unsqueeze(1)
                    # 每一行概率归一化
                    updated_model_prob = (updated_model_prob.T / updated_model_prob.sum(axis=1)).T
                    # 更新后的概率转化为 对数概率
                    log_model_prob = torch.log(updated_model_prob)
                ## end of classifier guidance
                
                ## sample vt-1 from the probabilities 根据节点对数概率采样节点概率(节点特征)
                ligand_v_next = log_sample_categorical(log_model_prob)
                # import pdb;pdb.set_trace()
                v0_pred_traj.append(log_ligand_v_recon.clone().cpu())
                vt_pred_traj.append(log_model_prob.clone().cpu())
                # 将 V_t-1 设置为新的 V_t 准备下一个 t-1 时刻的去噪
                ligand_v = ligand_v_next

            # 记录过程
            ori_ligand_pos0 = pos0_from_e + offset[batch_ligand] # 还原小分子位置,把坐标中心加回来
            ori_ligand_pos = ligand_pos + offset[batch_ligand]
            pos0_traj.append(ori_ligand_pos0.clone().cpu())
            pos_traj.append(ori_ligand_pos.clone().cpu())
            v_traj.append(ligand_v.clone().cpu())

        ligand_pos = ligand_pos + offset[batch_ligand] # 还原小分子位置,把坐标中心加回来
        return {
            'pos': ligand_pos, # 生成的坐标
            'v': ligand_v, # 生成的节点类型
            'pos_traj': pos_traj, # 每个 t 时刻,X_t-1, 包含了梯度调整后的结果
            'pos0_traj': pos0_traj, # 每个 t 时刻,神经网络预测的 X_0
            'v_traj': v_traj, # 每个 t 时刻,V_t-1, 包含了梯度调整后的结果
            'v0_traj': v0_pred_traj, # 每个 t 时刻,神经网络预测的 V_0
            'vt_traj': vt_pred_traj # 每个 t 时刻,V_t
        }

三、训练梯度引导模型

虽然在分子生成时,可以同时使用多个梯度引导模型对 TargetDiff 的分子生成过程进行引导,但是在训练梯度引导模型时,是一个个单独训练的。如 2.1 部分描述的,每一个梯度引导模型,实际上都是一个 DockGuideNet3D() 模型,是基于不同的拟合值训练出来的。

训练梯度引导模型的命令是 (以 Binding Affinity 为例):

python scripts/train_dock_guide.py \
  configs/training_dock_guide.yml

关于训练梯度引导模型的配置文件,如下:

data:
  name: pl_dock_guide
  path: ./data/guide/crossdocked_v1.1_rmsd1.0_pocket10
  split: ./data/guide/crossdocked_pocket10_pose_split_dock_guide.pt
  index_path: ./data/guide/crossdocked_v1.1_rmsd1.0_pocket10/index.pkl
  transform:
    ligand_atom_mode: add_aromatic
    random_rot: False

model:
  model_mean_type: C0  # ['noise', 'C0']
  beta_schedule: sigmoid
  beta_start: 1.e-7
  beta_end: 2.e-3
  v_beta_schedule: cosine
  v_beta_s: 0.01
  num_diffusion_timesteps: 1000
  loss_v_weight: 100.
  sample_time_method: symmetric  # ['importance', 'symmetric']

  time_emb_dim: 4
  time_emb_mode: sin
  center_pos_mode: protein

  node_indicator: True
  model_type: egnn
  num_blocks: 1
  num_layers: 9
  hidden_dim: 128
  n_heads: 16
  edge_feat_dim: 4  # edge type feat
  num_r_gaussian: 20
  knn: 32 # !
  num_node_types: 8
  act_fn: relu
  norm: True
  cutoff_mode: knn  # [radius, none]
  ew_net_type: global  # [r, m, none]
  num_x2h: 1
  num_h2x: 1
  r_max: 10.
  x2h_out_fc: False
  sync_twoup: False
  update_x: False

train:
  seed: 2021
  batch_size: 16
  num_workers: 4
  n_acc_batch: 1
  max_iters: 200000
  val_freq: 2000
  pos_noise_std: 0.1
  max_grad_norm: 8.0
  bond_loss_weight: 1.0
  optimizer:
    type: adam
    lr: 0.001 #5.e-4
    weight_decay: 0
    beta1: 0.95
    beta2: 0.999
  scheduler:
    type: plateau
    factor: 0.6
    patience: 10
    min_lr: 1.e-6

其中,data 部分中的 name 为数据集的名字,会据此选择不同的数据加载模类; path 则为原始蛋白-小分子复合体系的保存位置,整个路径下的每一个文件夹是一个体系,默认是 ./data/guide/crossdocked_v1.1_rmsd1.0_pocket10。index_path 则是用于生成数据集的体系列表,是一个pkl文件,一般为 index.pkl。

在配置文件中,model 部分是模型的配置,其中包括了 噪音生成器的配置(如:model_mean_type,beta_schedule, num_diffusion_timesteps 等,注意,噪音调度器的配置要与主模型 TargetDiff 一致。 ),也包括了神经网络的配置(如: model_type, num_blocks 等)。train 部分则是一些训练梯度引导模型的超参数,例如:batch_size, seed, optimizer, scheduler 等。

3.1 __main__ 函数

整个__main__ 函数比较简单。这里就简单介绍一下。

(1) 加载参数,加载配置文件,创建日志路径,复制模型配置文件,复制代码。如下代码:

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('config', type=str)
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--logdir', type=str, default='./logs') # 日志路径
    parser.add_argument('--tag', type=str, default='')
    parser.add_argument('--train_report_iter', type=int, default=200) # 输出间隔
    args = parser.parse_args()

    # Load configs 加载配置文件
    config = misc.load_config(args.config)
    config_name = os.path.basename(args.config)[:os.path.basename(args.config).rfind('.')]
    misc.seed_all(config.train.seed)

    # Logging,加载配置文件路径设置
    log_dir = misc.get_new_log_dir(args.logdir, prefix=config_name, tag=args.tag)
    ckpt_dir = os.path.join(log_dir, 'checkpoints') # checkpoint 保存路径
    os.makedirs(ckpt_dir, exist_ok=True)
    vis_dir = os.path.join(log_dir, 'vis')
    os.makedirs(vis_dir, exist_ok=True)
    logger = misc.get_logger('train', log_dir)
    writer = torch.utils.tensorboard.SummaryWriter(log_dir)
    logger.info(args)
    logger.info(config)
    shutil.copyfile(args.config, os.path.join(log_dir, os.path.basename(args.config))) # 复制配置文件
    shutil.copytree('./models', os.path.join(log_dir, 'models')) # 复制梯度引导模型代码

(2) 初始化数据转换器

trans.FeaturizeProteinAtom() 和trans.FeaturizeLigandAtom() 可以别分别将蛋白和小分子的字典数据,转化为tensor矩阵),trans.RandomRotation() 设置坐标旋转。将初始化后的 trans.RandomRotation() ,trans.FeaturizeProteinAtom() 和trans.FeaturizeLigandAtom() 组成的 transform 输入到 get_dataset() 函数中,加载数据,生成字典 dataset 和 subsets (训练集和验证集)。

代码如下:

    # Transforms 输入数据转换器
    protein_featurizer = trans.FeaturizeProteinAtom()
    ligand_featurizer = trans.FeaturizeLigandAtom(config.data.transform.ligand_atom_mode)
    transform_list = [
        protein_featurizer,
        ligand_featurizer,
        trans.FeaturizeLigandBond(),
    ]
    # 坐标随机旋转
    if config.data.transform.random_rot:
        transform_list.append(trans.RandomRotation())
    transform = Compose(transform_list)

    # Datasets and loaders 加载数据,对数据进行转化
    logger.info('Loading dataset...')
    # 从口袋 pdb 文件。小分子 sdf 文件到数据字典。
    dataset, subsets = get_dataset(
        config=config.data,
        transform=transform,
        # heavy_only=config.data.heavy_only
        index_path=config.data.index_path
    )
    train_set, val_set = subsets['train'], subsets['test']
    logger.info(f'Training: {len(train_set)} Validation: {len(val_set)}')

get_dataset() 函数来自于 datasets/__init__.py 文件。因为这里训练的是梯度引导模型,因此,调用的是 PocketLigandPairDockGuideDataset() 类实现数据加载。

代码如下:

def get_dataset(config, *args, **kwargs):
    name = config.name
    root = config.path
    if name == 'pl':
        dataset = PocketLigandPairDataset(root, *args, **kwargs)
    elif name == 'pdbbind':
        dataset = PDBBindDataset(root, *args, **kwargs)
    elif name == 'pl_dock_guide':
        # 梯度引导模型加载数据
        dataset = PocketLigandPairDockGuideDataset(root, *args, **kwargs)
    else:
        raise NotImplementedError('Unknown dataset: %s' % name)

    if 'split' in config:
        # 根据 ./data/guide/crossdocked_pocket10_pose_split_dock_guide.pt 划分训练集和验证集
        split = torch.load(config.split)
        subsets = {k: Subset(dataset, indices=v) for k, v in split.items()}
        return dataset, subsets
    else:
        return dataset

关于 PocketLigandPairDockGuideDatase() 类的详细解析。与很多基于 EGNN 或者 SE3 的神经网络相同,PocketLigandPairDockGuideDatase() 类非常相似。唯一的区别是在_process函数中将 vina score 也作为一个 data 的值,后期将用于训练梯度引导模型的标签(可以替换为任何值作为标签)。

PocketLigandPairDockGuideDatase() 类的代码如下:

class PocketLigandPairDockGuideDataset(Dataset):

    def __init__(self, raw_path, transform=None, version='final', index_path=None, processed_path=None):
        '''
        raw_path:原始数据的路径。
        transform:可选的转换函数,用于在数据加载时对数据进行处理。
        version:数据处理的版本,影响处理后的文件名。
        index_path:指向索引文件的路径,如果为 None,则从 raw_path 中自动寻找 index.pkl 文件。
        processed_path:处理后的数据存储路径,如果为 None,则根据 raw_path 和 version 自动生成。
        self.db:用于存储 LMDB 数据库的连接。
        self.keys:存储数据库中的键(用于索引数据)。
        
        如果处理后的数据文件不存在,调用 _process() 方法处理原始数据并生成处理后的数据文件
        '''
        super().__init__()
        self.raw_path = raw_path.rstrip('/')
        # 数据集体系名字列表文件
        self.index_path = os.path.join(self.raw_path, 'index.pkl') if index_path is None else index_path
        # 如果数据路径中没有 lmdb 文件,则生成 lmdb 文件名,准备重新生成
        if processed_path is None:
            self.processed_path = os.path.join(os.path.dirname(self.raw_path),
                                            os.path.basename(self.raw_path) + f'_processed_dock_guide_{version}.lmdb')

            # 打印输出数据集的文件名
            print(f'processed_path: {self.processed_path}')

        else:
            self.processed_path = processed_path
        self.transform = transform # 从字典到矩阵的转换器
        self.db = None

        self.keys = None

        if not os.path.exists(self.processed_path):
            # 如果处理后的数据文件 lmdb 文件不存在,则 self._process() 重新生成
            print(f'{self.processed_path} does not exist, begin processing data')
            self._process()

    def _connect_db(self):
        """
            用于建立只读的 LMDB 数据库连接 
            最大数据容量为 10 GB
            Establish read-only database connection
        """
        assert self.db is None, 'A connection has already been opened.'
        self.db = lmdb.open(
            self.processed_path,
            map_size=10*(1024*1024*1024),   # 10GB 最大数据容量
            create=False,
            subdir=False,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )
        with self.db.begin() as txn:
            self.keys = list(txn.cursor().iternext(values=False))

    def _close_db(self):
        self.db.close()
        self.db = None
        self.keys = None
        
    def _process(self):
        '''
        用于处理原始数据并将其存储到 LMDB 数据库中
        '''

        db = lmdb.open(
            self.processed_path,
            map_size=10*(1024*1024*1024),   # 10GB
            create=True,
            subdir=False,
            readonly=False,  # Writable
        )
        # 加载构建数据集的列表
        with open(self.index_path, 'rb') as f:
            index = pickle.load(f)

        num_skipped = 0  # 跳过的体系总数
        with db.begin(write=True, buffers=True) as txn:
            # 逐个体系生成一个 data 字典,传到 lmdb 中
            for i, (pocket_fn, ligand_fn, _, _, vina, props) in enumerate(tqdm(index)):
                if pocket_fn is None: continue
                try:
                    # data_prefix = '/data/work/jiaqi/binding_affinity'
                    data_prefix = self.raw_path
                    pocket_dict = PDBProtein(os.path.join(data_prefix, pocket_fn)).to_dict_atom() # 提取蛋白 pdb 文件到字典
                    ligand_dict = parse_sdf_file(os.path.join(data_prefix, ligand_fn)) # 提取小分子文件到字典
                    # 合并蛋白和小分子的字典到 data 字典
                    data = ProteinLigandData.from_protein_ligand_dicts(
                        protein_dict=torchify_dict(pocket_dict),
                        ligand_dict=torchify_dict(ligand_dict),
                    )
                    data.protein_filename = pocket_fn
                    data.ligand_filename = ligand_fn
                    data.vina_dock = vina # vina 打分,标签
                    data = data.to_dict()  # avoid torch_geometric version issue
                    data.update(props)
                    txn.put(
                        key=str(i).encode(),
                        value=pickle.dumps(data)
                    )
                except:
                    num_skipped += 1
                    print('Skipping (%d) %s' % (num_skipped, ligand_fn, ))
                    continue
        db.close()
    
    def __len__(self):
        if self.db is None:
            self._connect_db()
        return len(self.keys)

    def __getitem__(self, idx):
        data = self.get_ori_data(idx)
        if self.transform is not None:
            data = self.transform(data)
        return data

    def get_ori_data(self, idx):
        if self.db is None:
            self._connect_db()
        key = self.keys[idx]
        data = pickle.loads(self.db.begin().get(key))
        data = ProteinLigandData(**data)
        data.id = idx
        assert data.protein_pos.size(0) > 0
        return data

(5) 将训练集和测试集,加载到 dataloader 中(批次化)

代码简单,如下:

需要关注一下 FOLLOW_BATCH 参数, 此参数标记字典中,哪些数据 (key) 还需要记录批次中属于哪个图的 key,默认是 FOLLOW_BATCH = ('protein_element', 'ligand_element', 'ligand_bond_type',) 。

# follow_batch = ['protein_element', 'ligand_element']
    collate_exclude_keys = ['ligand_nbh_list']
    # 训练集迭代器 (批次化)
    train_iterator = utils_train.inf_iterator(DataLoader(
        train_set,
        batch_size=config.train.batch_size,
        shuffle=True,
        num_workers=config.train.num_workers,
        follow_batch=FOLLOW_BATCH, # 字典中,还需要记录批次中属于哪个图的 key
        exclude_keys=collate_exclude_keys # 不需要的 key
    ))
    # 验证集 dataloader(将矩阵合在一起批次化)
    val_loader = DataLoader(val_set, config.train.batch_size, shuffle=False,
                            follow_batch=FOLLOW_BATCH, exclude_keys=collate_exclude_keys)

(6) 加载梯度引导模型 DockGuideNet3D,训练优化器,梯度调节器。代码如下:

    # Model 梯度引导模型
    logger.info('Building model...')
    model = DockGuideNet3D(
        config.model,
        protein_atom_feature_dim=protein_featurizer.feature_dim,
        ligand_atom_feature_dim=ligand_featurizer.feature_dim
    ).to(args.device)
    # print(model)
    print(f'protein feature dim: {protein_featurizer.feature_dim} ligand feature dim: {ligand_featurizer.feature_dim}')
    logger.info(f'# trainable parameters: {misc.count_parameters(model) / 1e6:.4f} M')

    # Optimizer and scheduler 优化器和调度器
    optimizer = utils_train.get_optimizer(config.train.optimizer, model)
    scheduler = utils_train.get_scheduler(config.train.scheduler, optimizer)

(7) 批次训练函数,包括:调用梯度引导模型计算损失,梯度裁剪,记录log。代码如下:

    # 批次的训练,n_acc_batch 个批次更新一次权重
    def train(it):
        model.train()
        optimizer.zero_grad()
        avg_loss = 0
        for _ in range(config.train.n_acc_batch):
            batch = next(train_iterator).to(args.device)
            protein_noise = torch.randn_like(batch.protein_pos) * config.train.pos_noise_std
            gt_protein_pos = batch.protein_pos + protein_noise
            # 计算损失
            loss = model.get_loss(
                protein_pos=gt_protein_pos,
                protein_v=batch.protein_atom_feature.float(),
                batch_protein=batch.protein_element_batch,

                ligand_pos=batch.ligand_pos,
                ligand_v=batch.ligand_atom_feature_full,
                batch_ligand=batch.ligand_element_batch,
                dock=batch[config.train.get("target", "vina_dock")]        # TODO: show warning if target not in config
            )
            loss = loss / config.train.n_acc_batch
            avg_loss += loss
            loss.backward()
        # 梯度剪裁
        orig_grad_norm = clip_grad_norm_(model.parameters(), config.train.max_grad_norm)
        # 更新梯度
        optimizer.step()

        # log 批次间隔
        if it % args.train_report_iter == 0:
            logger.info(
                '[Train] Iter %d | Loss %.6f | Lr: %.6f | Grad Norm: %.6f' % (
                    it, avg_loss, optimizer.param_groups[0]['lr'], orig_grad_norm
                )
            )
            
            writer.add_scalar(f'train/loss', loss, it)
            writer.add_scalar('train/lr', optimizer.param_groups[0]['lr'], it)
            writer.add_scalar('train/grad', orig_grad_norm, it)
            writer.flush()

(7) 验证函数,与训练函数很相似,略过,代码如下:

    # 验证函数
    def validate(it):
        # fix time steps
        sum_loss, sum_n = 0, 0
        all_preds = []
        all_gts = []
        with torch.no_grad():
            model.eval()
            for batch in tqdm(val_loader, desc='Validate'):
                batch = batch.to(args.device)
                batch_size = batch.num_graphs
                for t in np.linspace(0, model.num_timesteps - 1, 10).astype(int):
                    time_step = torch.tensor([t] * batch_size).to(args.device)
                    loss = model.get_loss(
                        protein_pos=batch.protein_pos,
                        protein_v=batch.protein_atom_feature.float(),
                        batch_protein=batch.protein_element_batch,

                        ligand_pos=batch.ligand_pos,
                        ligand_v=batch.ligand_atom_feature_full,
                        batch_ligand=batch.ligand_element_batch,
                        time_step=time_step,
                        dock=batch[config.train.get("target", "vina_dock")]        # TODO: show warning if target not in config
                    )

                    sum_loss += float(loss) * batch_size
                    sum_n += batch_size

        avg_loss = sum_loss / sum_n
        # 根据平均损失更新调度器
        if config.train.scheduler.type == 'plateau':
            scheduler.step(avg_loss)
        elif config.train.scheduler.type == 'warmup_plateau':
            scheduler.step_ReduceLROnPlateau(avg_loss)
        else:
            scheduler.step()

        # 记录
        logger.info(
            '[Validate] Iter %05d | Loss %.6f' % (
                it, avg_loss 
            )
        )
        writer.add_scalar('val/loss', avg_loss, it)
        writer.flush()
        return avg_loss

(8) 在 try 语句中,尝试进行模型的训练、验证并每个间隔保存新的模型checkpoint。代码如下:

    try:
        best_loss, best_iter = None, None
        # 逐次迭代
        for it in range(1, config.train.max_iters + 1):
            # with torch.autograd.detect_anomaly():
            train(it) # 训练
            # 验证
            if it % config.train.val_freq == 0 or it == config.train.max_iters:
                val_loss = validate(it)
                if best_loss is None or val_loss < best_loss:
                    logger.info(f'[Validate] Best val loss achieved: {val_loss:.6f}')
                    best_loss, best_iter = val_loss, it
                    ckpt_path = os.path.join(ckpt_dir, '%d.pt' % it)
                    for old_ckpt in glob.glob(f"{ckpt_dir}/*.pt"):
                        os.remove(old_ckpt) 
                    # 保存模型 checkpoint
                    torch.save({
                        'config': config,
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'scheduler': scheduler.state_dict(),
                        'iteration': it,
                    }, ckpt_path)
                else:
                    logger.info(f'[Validate] Val loss is not improved. '
                                f'Best val loss: {best_loss:.6f} at iter {best_iter}')
    except KeyboardInterrupt:
        logger.info('Terminating...')

至此,整个训练梯度引导模型的部分就已经结束了。其中可以看出,剩下的关键就在于梯度引导模型 DockGuideNet3D 的定义了。由于 DockGuideNet3D 会使用主模型中的噪音生成器等,因此先介绍 主模型 TargetDiff ,即 ScorePosNet3D 类。

四、主模型 ScorePosNet3D 代码解析

梯度引导模型 ScorePosNet3D (也是 TargetDiff 模型)来源于 models.molopt_score_model。

(其实这一部分的内容也适合其他的分子生成扩散模型,这些模型的相似度很高)

4.1 关于扩散模型的原理简介

扩散模型分为正向(添加噪音)和逆向(去噪音)过程,如下图。

在正向过程中,有0~T,共 T 个时刻。对于某一个时刻 t-1,都有一个 \beta _t 代表添加的正态分布噪音的权重, 数据 X_{t-1} 被添加噪音到 X_t,这一过程的公式为:

q(x_{t} | x_{t-1}) = N(x_{t}; \sqrt{1-\beta_{t}} x_{t-1}, \beta_{t}I)

将 1-\beta _t 定义为 \alpha _t, 那么上述公式可写为:

q(x_{t} | x_{t-1}) = N(x_{t}; \sqrt{\alpha_{t}} x_{t-1}, \beta_{t}I)

通过迭代,以及整套分布噪音的性质,从 X_0  到 X_t  的过程可以使用如下公式直接计算得到:

q(x_{t} | x_{0}) = N(x_{t}; \sqrt{\alpha_{cumprod, t}} x_{0}, (1- \alpha_{cumprod, t})I) x_{t} = \sqrt{\alpha_{cumprod, t}} x_{0} + \sqrt{1-\alpha_{cumprod, t}}\epsilon

其中,\alpha_{cumprod, t} 为 从0~t 的 \alpha _t 的累乘,即:

\alpha_{cumprod, t} = \prod_{0}^{t} \alpha_{i} = \prod_{o}^{t}(1-\beta_{i})

\sqrt{\alpha_{cumprod, t}} 和 \sqrt{1-\alpha_{cumprod, t}} 分别代表由 X_0  到 X_t 过程中,X_0 和 正态分布噪音的系数。

去噪过程,也称之反向过程,即在 X_0 (真实或者神经网络预测的)和 t 时刻 X_t 的情况下,采样上一 t-1 时刻先验的 X_{t-1},其公式可以写为:

x_{t-1} = \frac{\sqrt{\alpha_{cumprod, t-1}} \beta_{t}}{1-\alpha_{cumprod,t}} x_{0} + \frac{\sqrt{\alpha_{t}}(1-\alpha_{cumprod,t-1})}{1-\alpha_{cumprod,t}} x_{t} + \sqrt{\beta_{t}} \epsilon

\frac{\sqrt{\alpha_{cumprod, t-1}} \beta_{t}}{1-\alpha_{cumprod,t}} 和 \frac{\sqrt{\alpha_{t}}(1-\alpha_{cumprod,t-1})}{1-\alpha_{cumprod,t}} 分别代表计算先验 X_{t-1} 时,X_0 和 X_t 的系数,\sqrt{\beta_{t}} 则是噪音协方差的系数。前两项,可以称为 X_{t-1} 的均值,最后一项可以称为X_{t-1} 的不确定性。在具体实践中,上述公式往往应用于坐标等连续变量。

但是在节点类型等分离变量时,使用贝叶斯推断,v_{t-1}的计算公式如下:

q(v_{t-1}|v_{t},v_{0}) = \frac{q(v_{t}|v_{t-1},v_{0}) \ast q(v_{t-1}|v_{0})}{q(v_{t}|v _{0})}

在实际操作中,可以使用归一化来替代 q(v_t | v_0 ),即:

q(v_{t-1}|v_{t},v_{0}) = \frac{q(v_{t}|v_{t-1},v_{0}) \ast q(v_{t-1}|v_{0})}{\sum_{ }^{ }{q(v_{t}|v_{t-1},v_{0}) \ast q(v_{t-1}|v_{0}) }}

注:离散变量的目标是,通过有限的类别或状态进行转移,模型关注的是离散状态的变化。连续变量的目标是处理图像、信号等连续数据的生成和去噪,模型关注如何通过时间序列恢复出原始连续数据。

注:不管是坐标还是节点类型,我们都是输入 0 时刻和 t 时刻的,V,X,预测 V_{t-1}X_{t-1}, 我们计算的都是后验。扩散模型计算的损失是真实的后验(使用真实的 X_0 和 V_0)和估计的后验(神经网络预测的 X_0V_0)分别计算的V_{t-1}X_{t-1} 之间的 KL 损失。节点类型和坐标在计算后验时,方法有所不同。

关于扩散模型的损失,在真实的扩散(向前)过程的逆过程 p(x_{t-1} | x_{t},x_{0} ) 是已知的,这一过程是先验。我们会让神经网络 \theta 来预测 X_0,即 x(\theta)_{0}, 然后,通过上述公式,计算 x(\theta)_{t-1}(这一过程为后验),即 p(x(\theta)_{t-1} | x_{t},x(\theta)_{0} ),然后计算 后验的 x(\theta)_{t-1} 和 先验的 x_{t-1} 的 KL 散度,做为神经网络的损失。因此,神经网络可以学习整个扩散/去噪过程中数据的分布。(注,实际上,在非分子生成领域,神经网络的损失是噪音的MSE,只有在最后几步 t(例如:1~5) 才是 KL 散度)

综上,为了训练一个扩散模型,有几个非常重要的参数,需要预先计算出来(往往在 __inti__ 函数中计算):

(1) \alpha _t 和 \beta _t,分别代表在一步(由 t 到 t+1)添加噪音过程中,X_{t-1} 和 噪音 \epsilon 的 ”系数“,以 \beta _t 为基础,\beta _t  直接定义为一个列表;

(2)\sqrt{\alpha_{cumprod, t}} 和 \sqrt{1-\alpha_{cumprod, t}} 分别代表由 X_0 到 X_t 过程中,X_0 和 正态分布噪音 \epsilon 的系数,用于添加噪音的过程;

(3)\frac{\sqrt{\alpha_{cumprod, t-1}} \beta_{t}}{1-\alpha_{cumprod,t}} 和 \frac{\sqrt{\alpha_{t}}(1-\alpha_{cumprod,t-1})}{1-\alpha_{cumprod,t}} 分别代表计算后验/先验的 X_{t-1} 时,X_0X_t 的系数,\sqrt{\beta_{t}} 则是噪音协方差的系数,用于计算损失;

4.2 主模型 ScorePosNet3D 代码解析

结合上述 3.1 过程中,对 ScorePosNet3D 其中的方法,逐个介绍。

作为一个分子生成扩散模型,具有两个重要的功能,一个是计算损失(训练)给优化器,对应的是 get_diffusion_loss() 方法;另一个是分子生成,对应的是 TargetDiff 自己的分子生成 sample_diffusion 函数。

此外,TagMol 也对 TargetDiff 的分子生成过程做了修改,以便实现梯度引导分子生成,在 ScorePosNet3D 还存在一个sample_guided_diffusion (或者 sample_multi_guided_diffusion ) 函数。

接下来,将按照初始化函数 __inti__、训练函数 get_diffusion_loss、分子生函数 sample_diffusion 的顺序对 ScorePosNet3D 进行全面介绍。梯度引导分子生成函数 sample_multi_guided_diffusion 已经在 2.6 部分进行了详细介绍,这里不重复。

4.2.1 __inti__ 函数

__inti__函数里面主要是预先设置了扩散模型中的参数,神经网络的超参数。

首先,(1) 基本设置:

    def __init__(self, config, protein_atom_feature_dim, ligand_atom_feature_dim):
        super().__init__()
        '''
        模型中神经网络的配置
        扩散相关的超参数系数
        '''

        self.config = config

        # variance schedule
        # 神经网络输出的目标,noise 代表噪音, C0 代表真实值
        self.model_mean_type = config.model_mean_type  # ['noise', 'C0']
        # 损失权重??
        self.loss_v_weight = config.loss_v_weight

(2) 扩散模型的超参数设置:

首先是时间t 的列表

        # 时间 t 的采样方法
        self.sample_time_method = config.sample_time_method  # ['importance', 'symmetric']

然后定义 β_t 和 α_t 列表:

        # 创建 β_t 和 α_t 列表
        if config.beta_schedule == 'cosine':
            alphas = cosine_beta_schedule(config.num_diffusion_timesteps, config.pos_beta_s) ** 2
            # print('cosine pos alpha schedule applied!')
            betas = 1. - alphas
        else:
            betas = get_beta_schedule(
                beta_schedule=config.beta_schedule,
                beta_start=config.beta_start,
                beta_end=config.beta_end,
                num_diffusion_timesteps=config.num_diffusion_timesteps,
            )
            alphas = 1. - betas

其中, cosine_beta_schedule 定了一个 cosine 噪音调度器 ,也是很常用的噪音调度器,先算β_t, 然后计算α_t。代码如下:

def cosine_beta_schedule(timesteps, s=0.008):
    """
    cosine schedule 噪音调度器 
    0~T 的 cos(1/t * 0.5 * Π),t 时刻的后累乘 除以 t前时刻的累乘, 然后再开根号 
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    """
    steps = timesteps + 1
    x = np.linspace(0, steps, steps)
    alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    alphas = (alphas_cumprod[1:] / alphas_cumprod[:-1])

    alphas = np.clip(alphas, a_min=0.001, a_max=1.)

    # Use sqrt of this, so the alpha in our paper is the alpha_sqrt from the
    # Gaussian diffusion in Ho et al.
    alphas = np.sqrt(alphas)
    return alphas

另外一个噪音调度器,先算α_t,然后计算 β_t ,代码如下:

def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
    def sigmoid(x):
        return 1 / (np.exp(-x) + 1)

    if beta_schedule == "quad":
        betas = (
                np.linspace(
                    beta_start ** 0.5,
                    beta_end ** 0.5,
                    num_diffusion_timesteps,
                    dtype=np.float64,
                )
                ** 2
        )
    elif beta_schedule == "linear":
        betas = np.linspace(
            beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
        )
    elif beta_schedule == "const":
        betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
    elif beta_schedule == "jsd":  # 1/T, 1/(T-1), 1/(T-2), ..., 1
        betas = 1.0 / np.linspace(
            num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
        )
    elif beta_schedule == "sigmoid":
        betas = np.linspace(-6, 6, num_diffusion_timesteps)
        betas = sigmoid(betas) * (beta_end - beta_start) + beta_start
    else:
        raise NotImplementedError(beta_schedule)
    assert betas.shape == (num_diffusion_timesteps,)
    return betas

不同的噪音调度器,意味着不同的信息的衰减曲线。

计算添加噪音过程的 \sqrt{\alpha_{cumprod, t}}  和 \sqrt{1-\alpha_{cumprod, t}} 分别对应下述代码中的 self.sqrt_alphas_cumprod 和 self.sqrt_one_minus_alphas_cumprod,代表由 X_0 到 X_t 过程中,X_0 和 正态分布噪音的系数,用于添加噪音的过程;

        # α_t 的累乘
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        # tensor 矩阵化
        self.betas = to_torch_const(betas)
        self.num_timesteps = self.betas.size(0)
        self.alphas_cumprod = to_torch_const(alphas_cumprod)
        self.alphas_cumprod_prev = to_torch_const(alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others 添加噪音过程的超参数
        # 添加噪音过程中 X_0 的比例
        self.sqrt_alphas_cumprod = to_torch_const(np.sqrt(alphas_cumprod)) 
        # 添加噪音过程中 噪音 的比例
        self.sqrt_one_minus_alphas_cumprod = to_torch_const(np.sqrt(1. - alphas_cumprod))
        self.sqrt_recip_alphas_cumprod = to_torch_const(np.sqrt(1. / alphas_cumprod))
        self.sqrt_recipm1_alphas_cumprod = to_torch_const(np.sqrt(1. / alphas_cumprod - 1))

然后是计算先验/后验的,由 X_t 和 X_0 计算 X_t-1 的超参数, \frac{\sqrt{\alpha_{cumprod, t-1}} \beta_{t}}{1-\alpha_{cumprod,t}} 和 \frac{\sqrt{\alpha_{t}}(1-\alpha_{cumprod,t-1})}{1-\alpha_{cumprod,t}} ,分别对应下面代码的 self.posterior_mean_c0_coef 和 

self.posterior_mean_ct_coef, 代表 X_0 和 X_t  的系数;\sqrt{\beta_{t}} 则于代码中的 posterior_variance 对应,噪音协方差的系数。

        # calculations for posterior q(x_{t-1} | x_t, x_0) 计算先验 X_t-1 的 超参数
        # 给定当前状态 x_t 和 x_0​ 的情况下,前一个时间步 x_{t-1} 的方差。这些方差用于描述从 x_t​ 到 x_{t-1}​ 过程中引入的噪声量。
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        # X_0 的系数
        self.posterior_mean_c0_coef = to_torch_const(betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        # X_t 的系数
        self.posterior_mean_ct_coef = to_torch_const(
            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))
        # log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        
        self.posterior_var = to_torch_const(posterior_variance) # 后验方差
        self.posterior_logvar = to_torch_const(np.log(np.append(self.posterior_var[1], self.posterior_var[1:]))) # 后验对数方差

对于节点特征而言,因为节点特征是 one-hot 的,为了数值稳定,需要将噪音添加在对数空间上操作,那么就需要计算对数空间的 \alpha _t\beta _t, 以及噪音添加过程中的其他参数。节点特征的 log(\sqrt{\alpha_{cumprod, t}}) 和 log(\sqrt{1-\alpha_{cumprod, t}}) 分别对应下述代码中的 self.log_alphas_cumprod_v 和 self.log_one_minus_alphas_cumprod_v

        # atom type diffusion schedule in log space,原子类型的扩散在对数空间实现,为了数值稳定
        # 计算 α_t
        if config.v_beta_schedule == 'cosine':
            alphas_v = cosine_beta_schedule(self.num_timesteps, config.v_beta_s)
            # print('cosine v alpha schedule applied!')
        else:
            raise NotImplementedError
        # 计算 log(α_t)
        log_alphas_v = np.log(alphas_v)
        # 计算 log(α_t * α_t-1 .... α_0), 相当于α_t 的累积
        log_alphas_cumprod_v = np.cumsum(log_alphas_v)
        self.log_alphas_v = to_torch_const(log_alphas_v)
        # 计算 log(1-α_t)
        self.log_one_minus_alphas_v = to_torch_const(log_1_min_a(log_alphas_v))
        self.log_alphas_cumprod_v = to_torch_const(log_alphas_cumprod_v)
        # 计算 log((1-α_t) * (1-α_t-1) .... (1-α_0)), 相当于1-α_t 的累积
        self.log_one_minus_alphas_cumprod_v = to_torch_const(log_1_min_a(log_alphas_cumprod_v))

(3)神经网络的超参数配置

小分子的输入节点向量维度,以及隐藏层向量维度,蛋白节点的输入向量维度

        # model definition 神经网络模型定义
        self.hidden_dim = config.hidden_dim # 隐藏层维度
        self.num_classes = ligand_atom_feature_dim # 小分子的输入节点向量维度
        if self.config.node_indicator: # ????, 参数是 True
            emb_dim = self.hidden_dim - 1
        else:
            emb_dim = self.hidden_dim
        # atom embedding 蛋白原子的嵌入层
        self.protein_atom_emb = nn.Linear(protein_atom_feature_dim, emb_dim)

然后是,时间t 和小分子节点的嵌入层的设置。从代码中可以推测,时间t 的嵌入以后和小分子的嵌入合并在一起,二者形成了隐藏层的维度。

        # time embedding 时间和小分子的原子嵌入
        self.time_emb_dim = config.time_emb_dim
        # 时间 t 的嵌入模式
        self.time_emb_mode = config.time_emb_mode  # ['simple', 'sin']
        if self.time_emb_dim > 0:
            # 简单嵌入
            if self.time_emb_mode == 'simple':
                self.ligand_atom_emb = nn.Linear(ligand_atom_feature_dim + 1, emb_dim)
            # sin 嵌入 
            elif self.time_emb_mode == 'sin':
                self.time_emb = nn.Sequential(
                    SinusoidalPosEmb(self.time_emb_dim),
                    nn.Linear(self.time_emb_dim, self.time_emb_dim * 4),
                    nn.GELU(),
                    nn.Linear(self.time_emb_dim * 4, self.time_emb_dim)
                )
                self.ligand_atom_emb = nn.Linear(ligand_atom_feature_dim + self.time_emb_dim, emb_dim)
            else:
                raise NotImplementedError
        else:
            self.ligand_atom_emb = nn.Linear(ligand_atom_feature_dim, emb_dim)

然后就是神经网络的初始化

        self.refine_net_type = config.model_type # 神经网络,uni_o2,不是ENGG
        self.refine_net = get_refine_net(self.refine_net_type, config)
        # 节点特征网络
        self.v_inference = nn.Sequential(
            nn.Linear(self.hidden_dim, self.hidden_dim),
            ShiftedSoftplus(),
            nn.Linear(self.hidden_dim, ligand_atom_feature_dim),
        )

get_refine_net 初始化神经网络,结合项目的配置文件,知道神网络阔使用的是 uni_o2,不是 ENGG。self.v_inference 用于从 uni_o2 的输出节点嵌入中,预测节点类型概率。

完整的__inti__代码如下:

    def __init__(self, config, protein_atom_feature_dim, ligand_atom_feature_dim):
        super().__init__()
        '''
        模型中神经网络的配置
        扩散相关的超参数系数
        '''

        self.config = config

        # variance schedule
        # 神经网络输出的目标,noise 代表噪音, C0 代表真实值
        self.model_mean_type = config.model_mean_type  # ['noise', 'C0']
        # 损失权重??
        self.loss_v_weight = config.loss_v_weight
        # self.v_mode = config.v_mode
        # assert self.v_mode == 'categorical'
        # self.v_net_type = getattr(config, 'v_net_type', 'mlp')
        # self.bond_loss = getattr(config, 'bond_loss', False)
        # self.bond_net_type = getattr(config, 'bond_net_type', 'pre_att')
        # self.loss_bond_weight = getattr(config, 'loss_bond_weight', 0.)
        # self.loss_non_bond_weight = getattr(config, 'loss_non_bond_weight', 0.)

        # 时间 t 的采样方法
        self.sample_time_method = config.sample_time_method  # ['importance', 'symmetric']
        # self.loss_pos_type = config.loss_pos_type  # ['mse', 'kl']
        # print(f'Loss pos mode {self.loss_pos_type} applied!')
        # print(f'Loss bond net type: {self.bond_net_type} '
        #       f'bond weight: {self.loss_bond_weight} non bond weight: {self.loss_non_bond_weight}')

        # 创建 β_t 和 α_t 列表
        if config.beta_schedule == 'cosine':
            alphas = cosine_beta_schedule(config.num_diffusion_timesteps, config.pos_beta_s) ** 2
            # print('cosine pos alpha schedule applied!')
            betas = 1. - alphas
        else:
            betas = get_beta_schedule(
                beta_schedule=config.beta_schedule,
                beta_start=config.beta_start,
                beta_end=config.beta_end,
                num_diffusion_timesteps=config.num_diffusion_timesteps,
            )
            alphas = 1. - betas
        # α_t 的累乘
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        # tensor 矩阵化
        self.betas = to_torch_const(betas)
        self.num_timesteps = self.betas.size(0)
        self.alphas_cumprod = to_torch_const(alphas_cumprod)
        self.alphas_cumprod_prev = to_torch_const(alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others 添加噪音过程的超参数
        # 添加噪音过程中 X_0 的比例
        self.sqrt_alphas_cumprod = to_torch_const(np.sqrt(alphas_cumprod)) 
        # 添加噪音过程中 噪音 的比例
        self.sqrt_one_minus_alphas_cumprod = to_torch_const(np.sqrt(1. - alphas_cumprod))
        self.sqrt_recip_alphas_cumprod = to_torch_const(np.sqrt(1. / alphas_cumprod))
        self.sqrt_recipm1_alphas_cumprod = to_torch_const(np.sqrt(1. / alphas_cumprod - 1))

        # calculations for posterior q(x_{t-1} | x_t, x_0) 计算先验 X_t-1 的 超参数
        # 给定当前状态 x_t 和 x_0​ 的情况下,前一个时间步 x_{t-1} 的方差。这些方差用于描述从 x_t​ 到 x_{t-1}​ 过程中引入的噪声量。
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        # X_0 的系数
        self.posterior_mean_c0_coef = to_torch_const(betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        # X_t 的系数
        self.posterior_mean_ct_coef = to_torch_const(
            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))
        # log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        
        self.posterior_var = to_torch_const(posterior_variance) # 后验方差
        self.posterior_logvar = to_torch_const(np.log(np.append(self.posterior_var[1], self.posterior_var[1:]))) # 后验对数方差

        # atom type diffusion schedule in log space,原子类型的扩散在对数空间实现,为了数值稳定
        # 计算 α_t
        if config.v_beta_schedule == 'cosine':
            alphas_v = cosine_beta_schedule(self.num_timesteps, config.v_beta_s)
            # print('cosine v alpha schedule applied!')
        else:
            raise NotImplementedError
        # 计算 log(α_t)
        log_alphas_v = np.log(alphas_v)
        # 计算 log(α_t * α_t-1 .... α_0), 相当于α_t 的累积
        log_alphas_cumprod_v = np.cumsum(log_alphas_v)
        self.log_alphas_v = to_torch_const(log_alphas_v)
        # 计算 log(1-α_t)
        self.log_one_minus_alphas_v = to_torch_const(log_1_min_a(log_alphas_v))
        self.log_alphas_cumprod_v = to_torch_const(log_alphas_cumprod_v)
        # 计算 log((1-α_t) * (1-α_t-1) .... (1-α_0)), 相当于1-α_t 的累积
        self.log_one_minus_alphas_cumprod_v = to_torch_const(log_1_min_a(log_alphas_cumprod_v))

        self.register_buffer('Lt_history', torch.zeros(self.num_timesteps))
        self.register_buffer('Lt_count', torch.zeros(self.num_timesteps))

        # model definition 神经网络模型定义
        self.hidden_dim = config.hidden_dim # 隐藏层维度
        self.num_classes = ligand_atom_feature_dim # 小分子的输入节点向量维度
        if self.config.node_indicator: # ????, 参数是 True
            emb_dim = self.hidden_dim - 1
        else:
            emb_dim = self.hidden_dim

        # atom embedding 蛋白原子的嵌入层
        self.protein_atom_emb = nn.Linear(protein_atom_feature_dim, emb_dim)

        # center pos 中心模式
        self.center_pos_mode = config.center_pos_mode  # ['none', 'protein']

        # time embedding 时间和小分子的原子嵌入
        self.time_emb_dim = config.time_emb_dim
        # 时间 t 的嵌入模式
        self.time_emb_mode = config.time_emb_mode  # ['simple', 'sin']
        if self.time_emb_dim > 0:
            # 简单嵌入
            if self.time_emb_mode == 'simple':
                self.ligand_atom_emb = nn.Linear(ligand_atom_feature_dim + 1, emb_dim)
            # sin 嵌入 
            elif self.time_emb_mode == 'sin':
                self.time_emb = nn.Sequential(
                    SinusoidalPosEmb(self.time_emb_dim),
                    nn.Linear(self.time_emb_dim, self.time_emb_dim * 4),
                    nn.GELU(),
                    nn.Linear(self.time_emb_dim * 4, self.time_emb_dim)
                )
                self.ligand_atom_emb = nn.Linear(ligand_atom_feature_dim + self.time_emb_dim, emb_dim)
            else:
                raise NotImplementedError
        else:
            self.ligand_atom_emb = nn.Linear(ligand_atom_feature_dim, emb_dim)

        self.refine_net_type = config.model_type # 神经网络,uni_o2,不是ENGG
        self.refine_net = get_refine_net(self.refine_net_type, config)
        # 节点特征网络
        self.v_inference = nn.Sequential(
            nn.Linear(self.hidden_dim, self.hidden_dim),
            ShiftedSoftplus(),
            nn.Linear(self.hidden_dim, ligand_atom_feature_dim),
        )

4.2.2 get_diffusion_loss 函数

get_diffusion_loss 函数输入蛋白坐标、蛋白节点类型、蛋白图 mask、小分子坐标、小分子节点类型、小分子图 mask 、以及时间步 t (默认为 None),计算扩散模型的损失。

(1) 计算批次中图的数量,去坐标中心

        num_graphs = batch_protein.max().item() + 1 # 蛋白/小分子图数量
        protein_pos, ligand_pos, _ = center_pos(
            protein_pos, ligand_pos, batch_protein, batch_ligand, mode=self.center_pos_mode) # 坐标中心归0

(2) 采样时间 t,以及时间 t 出现的概率

        if time_step is None:
            # 采样
            time_step, pt = self.sample_time(num_graphs, protein_pos.device, self.sample_time_method)
        else:
            # 按参数指定
            pt = torch.ones_like(time_step).float() / self.num_timesteps

关于 self.sample_time 函数,按照权重或者对称随机采样与图数量一致的时间步 t ,代码及注释如下。

    def sample_time(self, num_graphs, device, method):
        if method == 'importance':
            # 使用基于权重的重要性采样方法
            if not (self.Lt_count > 10).all():
                return self.sample_time(num_graphs, device, method='symmetric')

            Lt_sqrt = torch.sqrt(self.Lt_history + 1e-10) + 0.0001
            Lt_sqrt[0] = Lt_sqrt[1]  # Overwrite decoder term with L1.
            pt_a照分布概率多项式采样 时间步
            time_step = torch.multinomial(pt_all, num_samples=num_graphs, replacement=True)
            # 对应时间步的概率
            pt = pt_all.gather(dim=0, index=time_step)
            return time_step, pt

        elif method == 'symmetric':
            # 对称采样
            # 随机抽取 图一半数量的 t 
            time_step = torch.randint(
                0, self.num_timesteps, size=(num_graphs // 2 + 1,), device=device)
            # 对称生成另一半数量的时间步 t 
            time_step = torch.cat(
                [time_step, self.num_timesteps - time_step - 1], dim=0)[:num_graphs]
            # 概率
            pt = torch.ones_like(time_step).float() / self.num_timesteps
            return time_step, ptll = Lt_sqrt / Lt_sqrt.sum() # 分布概率

            # 按

        else:
            raise ValueError

(3) 提取时间步 t 对应的 \alpha_{cumprod, t} :

a = self.alphas_cumprod.index_select(0, time_step)  # (num_graphs, )

(4) 为小分子节点特征以及坐标添加噪音,分别生成添加噪音后的 ligand_v_perturbed, ligand_pos_perturbed。

        # 扩展 α_cumprod,t 的维度与分子节点数相同
        a_pos = a[batch_ligand].unsqueeze(-1)  # (num_ligand_atoms, 1) 
        # 分子的噪音
        pos_noise = torch.zeros_like(ligand_pos) 
        pos_noise.normal_()
        # 扰动小分子的坐标,公式:Xt = a.sqrt() * X0 + (1-a).sqrt() * eps
        ligand_pos_perturbed = a_pos.sqrt() * ligand_pos + (1.0 - a_pos).sqrt() * pos_noise  # pos_noise * std
        # 节点类型 one-hot 的 log 概率,然后添加噪音扰动, 公式:Vt = a * V0 + (1-a) / K ,
        log_ligand_v0 = index_to_log_onehot(ligand_v, self.num_classes)
        ligand_v_perturbed, log_ligand_vt = self.q_v_sample(log_ligand_v0, time_step, batch_ligand)

其中,坐标和节点特征添加噪音对应上文公式:

x_{t} = \sqrt{\alpha_{cumprod, t}} x_{0} + \sqrt{1-\alpha_{cumprod, t}}\epsilon

关于 index_to_log_onehot 函数,给定的节点特征转换为 one-hot 编码的形式,并返回其对数表示(概率),代码如下:

def index_to_log_onehot(x, num_classes):
    '''
    定的节点特征转换为 one-hot 编码的形式,并返回其对数表示(概率)
    '''
    assert x.max().item() < num_classes, f'Error: {x.max().item()} >= {num_classes}'
    x_onehot = F.one_hot(x, num_classes)
    # permute_order = (0, -1) + tuple(range(1, len(x.size())))
    # x_onehot = x_onehot.permute(permute_order)
    log_x = torch.log(x_onehot.float().clamp(min=1e-30))
    return log_x

关于 q_v_sample 函数,在对数空间,添加噪音到节点特征, 代码如下:

    def q_v_sample(self, log_v0, t, batch):
        '''
        在对数空间,添加噪音到节点特征
        '''
        # 添加噪音到对数概率
        log_qvt_v0 = self.q_v_pred(log_v0, t, batch)
        # 采样节点特征各类别
        sample_index = log_sample_categorical(log_qvt_v0)
        # 类别概率转化为 one-hot (去对数化)
        log_sample = index_to_log_onehot(sample_index, self.num_classes)
        return sample_index, log_sample

其中,self.q_v_pred 实现对数空间的添加节点特征噪音,代码如下:

    def q_v_pred(self, log_v0, t, batch):
        # compute q(vt | v0)
        log_cumprod_alpha_t = extract(self.log_alphas_cumprod_v, t, batch)
        log_1_min_cumprod_alpha = extract(self.log_one_minus_alphas_cumprod_v, t, batch)
        # 节点特征噪音为 1/类别数
        log_probs = log_add_exp(
            log_v0 + log_cumprod_alpha_t,
            log_1_min_cumprod_alpha - np.log(self.num_classes)
        )
        return log_probs

index_to_log_onehot 前面已经介绍过,不再重复。log_sample_categorical 代码如下:

def log_sample_categorical(logits):
    '''
    基于 Gumbel 分布 的离散分布采样方法。
    它通过对每个类别的 logits 添加 Gumbel 噪声,
    然后选择具有最大值的类别。这样的方法可以避免直接计算 softmax,
    并且由于 Gumbel 分布的性质,这种采样方式是从 logits 所表示的分布中采样的
    '''
    uniform = torch.rand_like(logits)
    gumbel_noise = -torch.log(-torch.log(uniform + 1e-30) + 1e-30)
    sample_index = (gumbel_noise + logits).argmax(dim=-1)
    # sample_onehot = F.one_hot(sample, self.num_classes)
    # log_sample = index_to_log_onehot(sample, self.num_classes)
    return sample_index

(5) 基于扰动的小分子坐标和小分子节点类型,以及真实的蛋白坐标和节点类型,调用 forward 函数预测节点类型和坐标,代码如下:

 
        preds = self(
            protein_pos=protein_pos,
            protein_v=protein_v,
            batch_protein=batch_protein,

            init_ligand_pos=ligand_pos_perturbed,
            init_ligand_v=ligand_v_perturbed,
            batch_ligand=batch_ligand,
            time_step=time_step
        )

        pred_ligand_pos, pred_ligand_v = preds['pred_ligand_pos'], preds['pred_ligand_v']

关于 forward 函数, 将时间步 t 与扰动后的小分子节点特征合并,然后小分子和蛋白的节点特征进行嵌入,随后,输入到神经网络 refine_net 中预测小分子的初始的坐标和节点特征。代码如下:

   def forward(self, protein_pos, protein_v, batch_protein, init_ligand_pos, init_ligand_v, batch_ligand,
                time_step=None, return_all=False, fix_x=False, guide=None):
        
        batch_size = batch_protein.max().item() + 1 # 批次中图的数量
        init_ligand_v = F.one_hot(init_ligand_v, self.num_classes).float() # 小分子节点类型转为 one-hot 
        ## time embedding - currently not useful, generally passed as zero only
        ## if used, it is concatenated to the ligand features
        # 时间步 t 进行嵌入,然后与小分子的节点类型进行concat
        if self.time_emb_dim > 0:
            if self.time_emb_mode == 'simple':
                input_ligand_feat = torch.cat([
                    init_ligand_v,
                    (time_step / self.num_timesteps)[batch_ligand].unsqueeze(-1)
                ], -1)
            elif self.time_emb_mode == 'sin':
                time_feat = self.time_emb(time_step)
                input_ligand_feat = torch.cat([init_ligand_v, time_feat], -1)
            else:
                raise NotImplementedError
        else:
            input_ligand_feat = init_ligand_v

        ## convert one-hot features into the embedding space
        # 蛋白和小分子的节点类型进行嵌入
        h_protein = self.protein_atom_emb(protein_v)
        init_ligand_h = self.ligand_atom_emb(input_ligand_feat)

        ## add 0 to the end of hidden embedding of protein for every atom
        ## add 1 to the end of hidden embedding of protein for every atom
        if self.config.node_indicator:
            # 标记 哪个节点是属于蛋白,哪个节点属于小分子
            h_protein = torch.cat([h_protein, torch.zeros(len(h_protein), 1).to(h_protein)], -1)
            init_ligand_h = torch.cat([init_ligand_h, torch.ones(len(init_ligand_h), 1).to(h_protein)], -1)

        ## combines hidden states of protein and ligand into one set
        ## mask_ligand is used to keep track of which atom belongs to which type
        ## needed for forward pass of refine_net (uni_transformer)
        # 合并蛋白和小分子的坐标、节点特征以及批次信息,按照批次信息进行重排顺序
        h_all, pos_all, batch_all, mask_ligand = compose_context(
            h_protein=h_protein,
            h_ligand=init_ligand_h,
            pos_protein=protein_pos,
            pos_ligand=init_ligand_pos,
            batch_protein=batch_protein,
            batch_ligand=batch_ligand,
        )

        # 预测真实的 x 和 h
        if guide is None:
            # 没有引导的情况下,TargetDiff
            outputs = self.refine_net(h_all, pos_all, mask_ligand, batch_all, return_all=return_all, fix_x=fix_x)
        else:
            # 在有梯度引导的情况下,TagMol
            outputs = self.refine_net.forward_guided(
                h_all, pos_all, mask_ligand, batch_all, return_all=return_all, fix_x=fix_x, 
                guide=guide
            )
        final_pos, final_h = outputs['x'], outputs['h']
        ## final positions of ligand only is needed (protein positions are not changed or used)
        final_ligand_pos, final_ligand_h = final_pos[mask_ligand], final_h[mask_ligand]
        ## predict classes of atom-categories at the end of all the layers
        # 进一步预测小分子的节点类型
        final_ligand_v = self.v_inference(final_ligand_h)

        # 输出内容
        preds = {
            'pred_ligand_pos': final_ligand_pos,
            'pred_ligand_v': final_ligand_v,
            'final_h': final_h,
            'final_ligand_h': final_ligand_h
        }
        if return_all:
            final_all_pos, final_all_h = outputs['all_x'], outputs['all_h']
            final_all_ligand_pos = [pos[mask_ligand] for pos in final_all_pos]
            final_all_ligand_v = [self.v_inference(h[mask_ligand]) for h in final_all_h]
            preds.update({
                'layer_pred_ligand_pos': final_all_ligand_pos,
                'layer_pred_ligand_v': final_all_ligand_v
            })
        return preds

神经网络 refine_net 对应的是 uni_o2 模型,将在后面介绍。

(6) 提取小分子的坐标/噪音,然后计算坐标部分 x_0 的损失,仅计算 mse 损失,如下:

        pred_ligand_pos, pred_ligand_v = preds['pred_ligand_pos'], preds['pred_ligand_v']
        pred_pos_noise = pred_ligand_pos - ligand_pos_perturbed
        # atom position
        if self.model_mean_type == 'noise':
            # 神经网络预测的是噪音
            # 根据噪音,计算真实的坐标
            pos0_from_e = self._predict_x0_from_eps(
                xt=ligand_pos_perturbed, eps=pred_pos_noise, t=time_step, batch=batch_ligand)
            # 计算后验 t-1 的小分子坐标 (后面没使用)
            pos_model_mean = self.q_pos_posterior(
                x0=pos0_from_e, xt=ligand_pos_perturbed, t=time_step, batch=batch_ligand)
        elif self.model_mean_type == 'C0':
            # 神经网络预测是的真实的小分子的坐标和节点类型
            # 计算后验 t-1 的小分子坐标 (后面没使用)
            pos_model_mean = self.q_pos_posterior(
                x0=pred_ligand_pos, xt=ligand_pos_perturbed, t=time_step, batch=batch_ligand)
        else:
            raise ValueError

注,虽然,在其中使用 self.q_pos_posterior 函数输入 X_0 和 X_t 计算了坐标 x_t 的 前一步 x_t-1,但是,并没有计算 x_t-1 的先验真实值与后验预测值的 KL 散度作为损失,而是直接使用 x_0 的 mse 作为损失。self.q_pos_posterior 函数与上文提及的公式对应:

x_{t-1} = \frac{\sqrt{\alpha_{cumprod, t-1}} \beta_{t}}{1-\alpha_{cumprod,t}} x_{0} + \frac{\sqrt{\alpha_{t}}(1-\alpha_{cumprod,t-1})}{1-\alpha_{cumprod,t}} x_{t} + \sqrt{\beta_{t}} \epsilon

self.q_pos_posterior 函数的代码如下:

    def q_pos_posterior(self, x0, xt, t, batch):
        # 计算 X_t-1
        # Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
        pos_model_mean = extract(self.posterior_mean_c0_coef, t, batch) * x0 + \
                         extract(self.posterior_mean_ct_coef, t, batch) * xt
        return pos_model_mean

(7) 计算先验的和后验的节点类型 v_t-1 的 KL 损失

        # atom type loss 计算节点类型的 KL 损失
        log_ligand_v_recon = F.log_softmax(pred_ligand_v, dim=-1) # 预测的小分子节点类型 one-hot
        # 后验的 V_t-1 (预测的)
        log_v_model_prob = self.q_v_posterior(log_ligand_v_recon, log_ligand_vt, time_step, batch_ligand)
        # 先验的 V_t-1 (先验的)
        log_v_true_prob = self.q_v_posterior(log_ligand_v0, log_ligand_vt, time_step, batch_ligand)
        # 计算节点的类型 V_t-1 的 KL 损失
        kl_v = self.compute_v_Lt(log_v_model_prob=log_v_model_prob, log_v0=log_ligand_v0,
                                 log_v_true_prob=log_v_true_prob, t=time_step, batch=batch_ligand)
        loss_v = torch.mean(kl_v)

其中,self.q_v_posterior 分别计算了先验的和后验的 V_t-1 的概率。需要注意的是,由于节点类型时离散变量,而使用了如下公式:

q(v_{t-1}|v_{t},v_{0}) = \frac{q(v_{t}|v_{t-1},v_{0}) \ast q(v_{t-1}|v_{0})}{\sum_{ }^{ }{q(v_{t}|v_{t-1},v_{0}) \ast q(v_{t-1}|v_{0}) }}

self.q_v_posterior 代码如下。(self.q_v_pred 上文已经介绍过,不再重复)

    def q_v_posterior(self, log_v0, log_vt, t, batch):
        # 对用的公式,q(vt-1 | vt, v0) = q(vt | vt-1, x0) * q(vt-1 | x0) / q(vt | x0)
        # 注意这里没有使用 坐标类似的公式
        t_minus_1 = t - 1
        # Remove negative values, will not be used anyway for final decoder 去除负值
        t_minus_1 = torch.where(t_minus_1 < 0, torch.zeros_like(t_minus_1), t_minus_1)
        # q(vt-1 | x0)
        log_qvt1_v0 = self.q_v_pred(log_v0, t_minus_1, batch)
        # self.q_v_pred_one_timestep 计算 q(vt | vt-1, x0) 
        unnormed_logprobs = log_qvt1_v0 + self.q_v_pred_one_timestep(log_vt, t, batch)
        # 归一化 V_t-1 的概率, 即 q(vt-1 | vt, v0)
        log_vt1_given_vt_v0 = unnormed_logprobs - torch.logsumexp(unnormed_logprobs, dim=-1, keepdim=True)
        return log_vt1_given_vt_v0

其中,self.q_v_pred 计算 q(vt-1 | x0), self.q_v_pred_one_timestep 实现一步的扩散 q(v_t | v_t-1) 。

但是似乎,self.q_v_pred_one_timestep(log_vt, t, batch) 有错误,应该改为 self.q_v_pred_one_timestep(log_qvt1_v0 , t, batch),但是不清楚这是不是真的错误,以及错误的影响。要试一下。

关于 self.compute_v_Lt, 计算真实的 V_t-1和预测的V_t-1 之间的 KL 散度(当t 不等于0 时)。当 t 等于0时,则计算真实的 V_t-1和预测的V_t-1 之间的 NLL,保证生成的分子与真实分子一致。代码如下:

    def compute_v_Lt(self, log_v_model_prob, log_v0, log_v_true_prob, t, batch):
        '''
        KL 散度(KL divergence)或负对数似然(Negative Log Likelihood, NLL)
        '''
        kl_v = categorical_kl(log_v_true_prob, log_v_model_prob)  # [num_atoms, ] # KL 散度
        decoder_nll_v = -log_categorical(log_v0, log_v_model_prob)  # L0 # 负对数似然
        assert kl_v.shape == decoder_nll_v.shape
        mask = (t == 0).float()[batch] # 如果 t=0 则不计算损失 KL 损失只计算 NLL
        loss_v = scatter_mean(mask * decoder_nll_v + (1. - mask) * kl_v, batch, dim=0) # KL  散度或者 NLL
        return loss_v

其中,categorical_kl 和 log_categoricald 的代码如下:

def log_categorical(log_x_start, log_prob):
    return (log_x_start.exp() * log_prob).sum(dim=1)

def categorical_kl(log_prob1, log_prob2):
    kl = (log_prob1.exp() * (log_prob1 - log_prob2)).sum(dim=1)
    return kl

(8) 计算总损失,并返回预测的小分子坐标、节点类型,以及噪音和损失。代码如下:(注意,坐标的 mse 的权重为 100, 按照配置文件)

        # 总损失为:坐标 X_0 的 mse(默认权重 100), 以及 V_t-1 的 KL 损失 
        loss = loss_pos + loss_v * self.loss_v_weight

        return {
            'loss_pos': loss_pos,
            'loss_v': loss_v,
            'loss': loss,
            'x0': ligand_pos,
            'pred_ligand_pos': pred_ligand_pos,
            'pred_ligand_v': pred_ligand_v,
            'pred_pos_noise': pred_pos_noise,
            'ligand_v_recon': F.softmax(pred_ligand_v, dim=-1)
        }

完整的 get_diffusion_loss 函数代码如下:

def get_diffusion_loss(
            self, protein_pos, protein_v, batch_protein, ligand_pos, ligand_v, batch_ligand, time_step=None
    ):
        num_graphs = batch_protein.max().item() + 1 # 蛋白/小分子图数量
        protein_pos, ligand_pos, _ = center_pos(
            protein_pos, ligand_pos, batch_protein, batch_ligand, mode=self.center_pos_mode) # 坐标中心归0

        # 1. sample noise levels # 设置时间 t, 采样/按参数指定
        if time_step is None:
            # 采样
            time_step, pt = self.sample_time(num_graphs, protein_pos.device, self.sample_time_method)
        else:
            # 按参数指定
            pt = torch.ones_like(time_step).float() / self.num_timesteps
        a = self.alphas_cumprod.index_select(0, time_step)  # (num_graphs, )

        # 2. perturb pos and v
        # 扩展 α_cumprod,t 的维度与分子节点数相同
        a_pos = a[batch_ligand].unsqueeze(-1)  # (num_ligand_atoms, 1) 
        # 分子的噪音
        pos_noise = torch.zeros_like(ligand_pos) 
        pos_noise.normal_()
        # 扰动小分子的坐标,公式:Xt = a.sqrt() * X0 + (1-a).sqrt() * eps
        ligand_pos_perturbed = a_pos.sqrt() * ligand_pos + (1.0 - a_pos).sqrt() * pos_noise  # pos_noise * std
        # 节点类型 one-hot 的 log 概率,然后添加噪音扰动, 公式:Vt = a * V0 + (1-a) / K ,
        log_ligand_v0 = index_to_log_onehot(ligand_v, self.num_classes)
        ligand_v_perturbed, log_ligand_vt = self.q_v_sample(log_ligand_v0, time_step, batch_ligand)

        # 3. forward-pass NN, feed perturbed pos and v, output noise
        # forward 函数预测真实的小分子节点类型和坐标
        preds = self(
            protein_pos=protein_pos,
            protein_v=protein_v,
            batch_protein=batch_protein,

            init_ligand_pos=ligand_pos_perturbed,
            init_ligand_v=ligand_v_perturbed,
            batch_ligand=batch_ligand,
            time_step=time_step
        )

        pred_ligand_pos, pred_ligand_v = preds['pred_ligand_pos'], preds['pred_ligand_v']
        pred_pos_noise = pred_ligand_pos - ligand_pos_perturbed
        # atom position
        if self.model_mean_type == 'noise':
            # 神经网络预测的是噪音
            # 根据噪音,计算真实的坐标
            pos0_from_e = self._predict_x0_from_eps(
                xt=ligand_pos_perturbed, eps=pred_pos_noise, t=time_step, batch=batch_ligand)
            # 计算后验 t-1 的小分子坐标 (后面没使用)
            pos_model_mean = self.q_pos_posterior(
                x0=pos0_from_e, xt=ligand_pos_perturbed, t=time_step, batch=batch_ligand)
        elif self.model_mean_type == 'C0':
            # 神经网络预测是的真实的小分子的坐标和节点类型
            # 计算后验 t-1 的小分子坐标 (后面没使用)
            pos_model_mean = self.q_pos_posterior(
                x0=pred_ligand_pos, xt=ligand_pos_perturbed, t=time_step, batch=batch_ligand)
        else:
            raise ValueError

        # atom pos loss 计算节点坐标的 MSE 损失
        if self.model_mean_type == 'C0':
            target, pred = ligand_pos, pred_ligand_pos
        elif self.model_mean_type == 'noise':
            target, pred = pos_noise, pred_pos_noise
        else:
            raise ValueError
        loss_pos = scatter_mean(((pred - target) ** 2).sum(-1), batch_ligand, dim=0)
        loss_pos = torch.mean(loss_pos)

        # atom type loss 计算节点类型的 KL 损失
        log_ligand_v_recon = F.log_softmax(pred_ligand_v, dim=-1) # 预测的小分子节点类型 one-hot
        # 后验的 V_t-1 (预测的)
        log_v_model_prob = self.q_v_posterior(log_ligand_v_recon, log_ligand_vt, time_step, batch_ligand)
        # 先验的 V_t-1 (先验的)
        log_v_true_prob = self.q_v_posterior(log_ligand_v0, log_ligand_vt, time_step, batch_ligand)
        # 计算节点的类型 V_t-1 的 KL 损失
        kl_v = self.compute_v_Lt(log_v_model_prob=log_v_model_prob, log_v0=log_ligand_v0,
                                 log_v_true_prob=log_v_true_prob, t=time_step, batch=batch_ligand)
        loss_v = torch.mean(kl_v)
        # 总损失为:坐标 X_0 的 mse(默认权重 100), 以及 V_t-1 的 KL 损失 
        loss = loss_pos + loss_v * self.loss_v_weight

        return {
            'loss_pos': loss_pos,
            'loss_v': loss_v,
            'loss': loss,
            'x0': ligand_pos,
            'pred_ligand_pos': pred_ligand_pos,
            'pred_ligand_v': pred_ligand_v,
            'pred_pos_noise': pred_pos_noise,
            'ligand_v_recon': F.softmax(pred_ligand_v, dim=-1)
        }

总结一下 get_diffusion_loss 函数,调用 forward 函数预测真实的坐标和节点类型,然后分别计算了坐标的 mse 损失,以及节点类型的 q(v_t-1 | v_t, v_0) 的 真实的和基于预测值的 KL 损失。在权重上,坐标的 mse 损失的权重是 KL 损失的 100 倍。

4.2.3 sample_diffusion 函数

sample_diffusion 函数是 TargetDiff 的分子生成函数(注:不是 TagMol 的分子生生成函数),用于逐步对初始化的小分子坐标和节点(均含有噪音)进行去噪。

sample_diffusion 函数输入 初始化的小分子坐标 init_ligand_pos 和小分子节点特征 init_ligand_v 、蛋白坐标 protein_pos、蛋白节点 protein_v、小分子mask 和蛋白的 mask  batch_protein, 输入小分子去噪过程中的节点特征 v0_pred_traj 和坐标 pos_traj、去噪过程中的神经网络预测的 v_0 pos0_traj 和 x_0  v0_pred_traj,以及最终生成的小分子坐标 ligand_pos 和节点特征 ligand_v。

这一部分 sample_diffusion 函数的代码与 2.6 部分 sample_multi_guided_diffusion 函数非常相似。只是少了梯度计算,并更新原子坐标的过程。

首先是基本设置,包括:去噪步数、批次中图数、坐标去中心、初始化去噪过程记录列表。

然后,就是在 for 循环里面,按照 T 的倒序,不断输入当前的 X_t 和 V_t 使用训练好的神经网络面预测 X_0 和 V_0,然后基于 X_0 和 V_0 以及当前的 X_t 和 V_t 使用去噪公式计算 X_t-1 和V_t-1,然后使用 X_t-1 和 V_t-1 替换 X_t 和 V_t。以此反复进行去噪,直到 T=0.

完整代码如下:

    @torch.no_grad()
    def sample_diffusion(self, protein_pos, protein_v, batch_protein,
                         init_ligand_pos, init_ligand_v, batch_ligand,
                         num_steps=None, center_pos_mode=None, pos_only=False):
        
        if num_steps is None:
            num_steps = self.num_timesteps
        # 批次中图的数量
        num_graphs = batch_protein.max().item() + 1

        ## Shifts the origin to the centre of mass of protein
        ## new protein positions, new ligand position and the difference from original position is in the offset
        # 去坐标中心
        protein_pos, init_ligand_pos, offset = center_pos(
            protein_pos, init_ligand_pos, batch_protein, batch_ligand, mode=center_pos_mode)

        pos_traj, v_traj = [], []
        v0_pred_traj, vt_pred_traj, pos0_traj = [], [], []
        ligand_pos, ligand_v = init_ligand_pos, init_ligand_v
        ## time sequence - going from 1000 to 1000 - num_steps
        time_seq = list(reversed(range(self.num_timesteps - num_steps, self.num_timesteps)))
        # 逐步去噪
        for i in tqdm(time_seq, desc='sampling', total=len(time_seq)):
            # 时间 t 扩散至图的数量
            t = torch.full(size=(num_graphs,), fill_value=i, dtype=torch.long, device=protein_pos.device)
            # 预测 v_0 和 x_0
            preds = self(
                protein_pos=protein_pos,
                protein_v=protein_v,
                batch_protein=batch_protein,

                init_ligand_pos=ligand_pos,
                init_ligand_v=ligand_v,
                batch_ligand=batch_ligand,
                time_step=t
            )
            # Compute posterior mean and variance 提取坐标预测值
            if self.model_mean_type == 'noise':
                pred_pos_noise = preds['pred_ligand_pos'] - ligand_pos
                pos0_from_e = self._predict_x0_from_eps(xt=ligand_pos, eps=pred_pos_noise, t=t, batch=batch_ligand)
                v0_from_e = preds['pred_ligand_v']
            elif self.model_mean_type == 'C0':
                pos0_from_e = preds['pred_ligand_pos']
                v0_from_e = preds['pred_ligand_v']
            else:
                raise ValueError

            # 计算坐标的 X_t-1
            pos_model_mean = self.q_pos_posterior(x0=pos0_from_e, xt=ligand_pos, t=t, batch=batch_ligand)
            # t 时刻的坐标方差
            pos_log_variance = extract(self.posterior_logvar, t, batch_ligand)
            # no noise when t == 0
            nonzero_mask = (1 - (t == 0).float())[batch_ligand].unsqueeze(-1)
            # 采样 X_t-1 
            ligand_pos_next = pos_model_mean + nonzero_mask * (0.5 * pos_log_variance).exp() * torch.randn_like(
                ligand_pos)
            ligand_pos = ligand_pos_next

            if not pos_only:
                # 预测的 V_0 (对数概率)
                log_ligand_v_recon = F.log_softmax(v0_from_e, dim=-1)
                # one-hot V_0 
                log_ligand_v = index_to_log_onehot(ligand_v, self.num_classes)
                # V_t-1
                log_model_prob = self.q_v_posterior(log_ligand_v_recon, log_ligand_v, t, batch_ligand)
                # 采样 V_t-1
                ligand_v_next = log_sample_categorical(log_model_prob)

                v0_pred_traj.append(log_ligand_v_recon.clone().cpu())
                vt_pred_traj.append(log_model_prob.clone().cpu())
                ligand_v = ligand_v_next

            # 记录采样轨迹
            ori_ligand_pos0 = pos0_from_e + offset[batch_ligand] # 去噪过程中神经网络预测的 X_0
            ori_ligand_pos = ligand_pos + offset[batch_ligand] # 去噪过程中采样的的 X_t
            pos_traj.append(ori_ligand_pos.clone().cpu())
            pos0_traj.append(ori_ligand_pos0.clone().cpu())
            v_traj.append(ligand_v.clone().cpu()) # 去噪过程中采样的的 v_t

        ligand_pos = ligand_pos + offset[batch_ligand]
        return {
            'pos': ligand_pos,
            'v': ligand_v,
            'pos_traj': pos_traj,
            'pos0_traj': pos0_traj,
            'v_traj': v_traj,
            'v0_traj': v0_pred_traj,
            'vt_traj': vt_pred_traj
        }

至此, TargetDiff 模型(即,主模型 ScorePosNet3D )就已经全部介绍完成。

五、梯度引导模型 (DockGuideNet3D)

在 3.1 部分,训练梯度引导模型中,调用了 DockGuideNet3D 的 get_loss 方法计算梯度引导模型的损失。在分子生成的 2.3 梯度引导的分子分生成中,调用了 DockGuideNet3D 的 get_gradients_guide 方法来计算梯度。

DockGuideNet3D 在训练过程中,使用扰动后的小分子坐标和节点类型,以Binding Affinity 作为训练的 loss 目标。 小分子坐标和节点类型的扰动模式与主模型 ScorePosNet3D 完全相同。

下面,将按照__init__ 、get_loss、get_gradients_guide 对 DockGuideNet3D 进行全面的介绍。(注:在 GitHub 中,DockGuideNet3D 的代码很长,因为直接继承了很多 ScorePosNet3D 的代码修改而来,里面包含了很多没有使用到的函数。此外,由于在时间t采样等部分,与 ScorePosNet3D 完全相同,因此,这里对 DockGuideNet3D 的介绍会简单一些。)

5.1 __init__ 

(1)神经网络的预测模式、坐标 mse 损失权重、时间 t 采样的方式、

        # 神经网络预测 噪音还是真实值,默认是 C0 真实值
        self.model_mean_type = config.model_mean_type  # ['noise', 'C0']
        self.loss_v_weight = config.loss_v_weight
        
        # 时间 t 的 采样,重要性采样 importance 或者 对称性采样 symmetric
        self.sample_time_method = config.sample_time_method  # ['importance', 'symmetric']

(2)坐标的噪音调度器 β及其参数(α,α累乘,根号下1-α,计算 x_t-1 的系数等)

        # 噪音调度器 β 每一步添加噪音的比例; α (即,1-β) 每一步保留原信息的比例
        if config.beta_schedule == 'cosine':
            alphas = cosine_beta_schedule(config.num_diffusion_timesteps, config.pos_beta_s) ** 2
            # print('cosine pos alpha schedule applied!')
            betas = 1. - alphas
        else:
            betas = get_beta_schedule(
                beta_schedule=config.beta_schedule,
                beta_start=config.beta_start,
                beta_end=config.beta_end,
                num_diffusion_timesteps=config.num_diffusion_timesteps,
            )
            alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0) # α的累乘
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) # α的累乘,从t=0开始

        self.betas = to_torch_const(betas)
        self.num_timesteps = self.betas.size(0)
        self.alphas_cumprod = to_torch_const(alphas_cumprod) # α的累乘
        self.alphas_cumprod_prev = to_torch_const(alphas_cumprod_prev) # α的累乘,从t=0开始

        # 从 x_0 到 x_t 的过程中, x_0 的比例
        self.sqrt_alphas_cumprod = to_torch_const(np.sqrt(alphas_cumprod)) # 根号下 α的累乘,从t=0开始
        # 从 x_0 到 x_t 的过程中, 噪音的比例
        self.sqrt_one_minus_alphas_cumprod = to_torch_const(np.sqrt(1. - alphas_cumprod)) # 根号下(1-α的累乘)
        self.sqrt_recip_alphas_cumprod = to_torch_const(np.sqrt(1. / alphas_cumprod))
        self.sqrt_recipm1_alphas_cumprod = to_torch_const(np.sqrt(1. / alphas_cumprod - 1))

        # 每个时刻 t 下的 x_t 的概率(不确定性/方差)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        # 计算 x_{t-1} 的 x_0 的系数
        self.posterior_mean_c0_coef = to_torch_const(betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        # 计算 x_{t-1} 的 x_t 的系数
        self.posterior_mean_ct_coef = to_torch_const(
            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))

        self.posterior_var = to_torch_const(posterior_variance) # 每个时刻 t 下的 x_t 的概率(不确定性/方差)
        self.posterior_logvar = to_torch_const(np.log(np.append(self.posterior_var[1], self.posterior_var[1:]))) # 对数方差

(3) 节点特征相关的噪音调度器

        # # 节点类型的噪音调度器
        # 计算 α_t
        if config.v_beta_schedule == 'cosine':
            alphas_v = cosine_beta_schedule(self.num_timesteps, config.v_beta_s)
            # print('cosine v alpha schedule applied!')
        else:
            raise NotImplementedError
        # 计算 log(α_t)
        log_alphas_v = np.log(alphas_v) 
        # 计算 log(α_t * α_t-1 .... α_0), 相当于α_t 的累积
        log_alphas_cumprod_v = np.cumsum(log_alphas_v)
        self.log_alphas_v = to_torch_const(log_alphas_v)
        # 计算 log(1-α_t)
        self.log_one_minus_alphas_v = to_torch_const(log_1_min_a(log_alphas_v))
        # 计算 log((1-α_t) * (1-α_t-1) .... (1-α_0)), 相当于1-α_t 的累积
        self.log_alphas_cumprod_v = to_torch_const(log_alphas_cumprod_v)
        self.log_one_minus_alphas_cumprod_v = to_torch_const(log_1_min_a(log_alphas_cumprod_v))

        # 设置β时的权重。只有在 config.sample_time_method == inprotance 时成立 
        self.register_buffer('Lt_history', torch.zeros(self.num_timesteps))
        self.register_buffer('Lt_count', torch.zeros(self.num_timesteps))

注:坐标和节点相关的噪音调度器中,有大量的参数是 DockGuideNet3D 梯度引导模型不会使用的,但是在 GitHub 中仍然保留了,这些属于多余的代码。

(4) 神经网络的配置。包括:隐藏层的维度、小分子的节点类型种类数、是否包含小分子/蛋白的mask、蛋白的嵌入层、小分子和时间步 t 的合并嵌入、get_refine_net 函数获取并实例化神经网络、推理 binding affinity 层、辅助断言(在开发模型测试使用,此处可忽略)。

        # model definition 神经网络的定义
        self.hidden_dim = config.hidden_dim # 隐藏层维度
        self.num_classes = ligand_atom_feature_dim # 小分子节点类型数
        if self.config.node_indicator:
            # 如果包含小分子/蛋白的标记mask
            emb_dim = self.hidden_dim - 1
        else:
            emb_dim = self.hidden_dim

        # atom embedding  原子 embeding 层
        self.protein_atom_emb = nn.Linear(protein_atom_feature_dim, emb_dim)

        # center pos 坐标中心
        self.center_pos_mode = config.center_pos_mode  # ['none', 'protein']

        # time embedding 时间嵌入层,与 小分子节点特征 embeding 层
        self.time_emb_dim = config.time_emb_dim
        self.time_emb_mode = config.time_emb_mode  # ['simple', 'sin']
        if self.time_emb_dim > 0:
            if self.time_emb_mode == 'simple':
                self.ligand_atom_emb = nn.Linear(ligand_atom_feature_dim + 1, emb_dim)
            elif self.time_emb_mode == 'sin':
                self.time_emb = nn.Sequential(
                    SinusoidalPosEmb(self.time_emb_dim),
                    nn.Linear(self.time_emb_dim, self.time_emb_dim * 4),
                    nn.GELU(),
                    nn.Linear(self.time_emb_dim * 4, self.time_emb_dim)
                )
                self.ligand_atom_emb = nn.Linear(ligand_atom_feature_dim + self.time_emb_dim, emb_dim)
            else:
                raise NotImplementedError
        else:
            self.ligand_atom_emb = nn.Linear(ligand_atom_feature_dim, emb_dim)

        # 实例化神经网络
        self.refine_net_type = config.model_type
        self.refine_net = get_refine_net(self.refine_net_type, config)
        # 推理 binding affinity 层
        self.dock_inference = nn.Sequential(
            nn.Linear(self.hidden_dim, self.hidden_dim),
            ShiftedSoftplus(),
            nn.Linear(self.hidden_dim, 1),
        )

        # TODO: show warning if drop_protein_in_guide not in config
        self.drop_protein_in_guide = config.get("drop_protein_in_guide", False)
        assert isinstance(self.drop_protein_in_guide, bool), f"`drop_protein_in_guide` should be True or False. Got: {self.drop_protein_in_guide}"

        # TODO: show warning if maximize_property not in config
        self.maximize_property = config.get("maximize_property", False)
        assert isinstance(self.maximize_property, bool), f"`maximize_property` should be True or False. Got: {self.maximize_property}"

        # 推理 binding affinity 层的输出类型
        self.problem_type = config.get("problem_type", "regression")
        assert self.problem_type in ["regression", "classification"], f"`problem_type` should be 'regression' or 'classification'. Got: {self.problem_type}"

关于 get_refine_net 函数,实例化一个神经网路。这个神经网络名为 egnn, 按照配置文件的定义,包含了9层,维度为128,距离的高斯维度为20的 EGNN 网络。代码如下:

def get_refine_net(refine_net_type, config):
    if refine_net_type == 'uni_o2':
        raise NotImplementedError(refine_net_type)
    elif refine_net_type == 'egnn':
        refine_net = EGNN(
            num_layers=config.num_layers,
            hidden_dim=config.hidden_dim,
            edge_feat_dim=config.edge_feat_dim,
            num_r_gaussian=config.num_r_gaussian,
            k=config.knn,
            cutoff=config.r_max,
            update_x=config.update_x
        )
    else:
        raise ValueError(refine_net_type)
    return refine_net

需要特别注意的是,DockGuideNet3D 的坐标和节点类型的噪音调度器与主模型 ScorePosNet3D 完全一致。完整的 __init__ 代码如下:

  def __init__(self, config, protein_atom_feature_dim, ligand_atom_feature_dim):
        super().__init__()
        self.config = config

        # variance schedule
        # 神经网络预测 噪音还是真实值,默认是 C0 真实值
        self.model_mean_type = config.model_mean_type  # ['noise', 'C0']
        self.loss_v_weight = config.loss_v_weight

       # 时间 t 的 采样,重要性采样 importance 或者 对称性采样 symmetric
        self.sample_time_method = config.sample_time_method  # ['importance', 'symmetric']

        # 噪音调度器 β 每一步添加噪音的比例; α (即,1-β) 每一步保留原信息的比例
        if config.beta_schedule == 'cosine':
            alphas = cosine_beta_schedule(config.num_diffusion_timesteps, config.pos_beta_s) ** 2
            # print('cosine pos alpha schedule applied!')
            betas = 1. - alphas
        else:
            betas = get_beta_schedule(
                beta_schedule=config.beta_schedule,
                beta_start=config.beta_start,
                beta_end=config.beta_end,
                num_diffusion_timesteps=config.num_diffusion_timesteps,
            )
            alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0) # α的累乘
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) # α的累乘,从t=0开始

        self.betas = to_torch_const(betas)
        self.num_timesteps = self.betas.size(0)
        self.alphas_cumprod = to_torch_const(alphas_cumprod) # α的累乘
        self.alphas_cumprod_prev = to_torch_const(alphas_cumprod_prev) # α的累乘,从t=0开始

        # calculations for diffusion q(x_t | x_{t-1}) and others,添加噪音过程
        # 从 x_0 到 x_t 的过程中, x_0 的比例
        self.sqrt_alphas_cumprod = to_torch_const(np.sqrt(alphas_cumprod)) # 根号下 α的累乘,从t=0开始
        # 从 x_0 到 x_t 的过程中, 噪音的比例
        self.sqrt_one_minus_alphas_cumprod = to_torch_const(np.sqrt(1. - alphas_cumprod)) # 根号下(1-α的累乘)
        self.sqrt_recip_alphas_cumprod = to_torch_const(np.sqrt(1. / alphas_cumprod))
        self.sqrt_recipm1_alphas_cumprod = to_torch_const(np.sqrt(1. / alphas_cumprod - 1))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        # 每个时刻 t 下的 x_t 的概率(不确定性/方差)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        # 计算 x_{t-1} 的 x_0 的系数
        self.posterior_mean_c0_coef = to_torch_const(betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        # 计算 x_{t-1} 的 x_t 的系数
        self.posterior_mean_ct_coef = to_torch_const(
            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))
        # log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.posterior_var = to_torch_const(posterior_variance) # 每个时刻 t 下的 x_t 的概率(不确定性/方差)
        self.posterior_logvar = to_torch_const(np.log(np.append(self.posterior_var[1], self.posterior_var[1:]))) # 对数方差

        # atom type diffusion schedule in log space 
        # # 节点类型的噪音调度器
        # 计算 α_t
        if config.v_beta_schedule == 'cosine':
            alphas_v = cosine_beta_schedule(self.num_timesteps, config.v_beta_s)
            # print('cosine v alpha schedule applied!')
        else:
            raise NotImplementedError
        # 计算 log(α_t)
        log_alphas_v = np.log(alphas_v) 
        # 计算 log(α_t * α_t-1 .... α_0), 相当于α_t 的累积
        log_alphas_cumprod_v = np.cumsum(log_alphas_v)
        self.log_alphas_v = to_torch_const(log_alphas_v)
        # 计算 log(1-α_t)
        self.log_one_minus_alphas_v = to_torch_const(log_1_min_a(log_alphas_v))
        # 计算 log((1-α_t) * (1-α_t-1) .... (1-α_0)), 相当于1-α_t 的累积
        self.log_alphas_cumprod_v = to_torch_const(log_alphas_cumprod_v)
        self.log_one_minus_alphas_cumprod_v = to_torch_const(log_1_min_a(log_alphas_cumprod_v))

        # 设置β时的权重。只有在 config.sample_time_method == inprotance 时成立 
        self.register_buffer('Lt_history', torch.zeros(self.num_timesteps))
        self.register_buffer('Lt_count', torch.zeros(self.num_timesteps))

        # model definition 神经网络的定义
        self.hidden_dim = config.hidden_dim # 隐藏层维度
        self.num_classes = ligand_atom_feature_dim # 小分子节点类型数
        if self.config.node_indicator:
            # 如果包含小分子/蛋白的标记mask
            emb_dim = self.hidden_dim - 1
        else:
            emb_dim = self.hidden_dim

        # atom embedding  原子 embeding 层
        self.protein_atom_emb = nn.Linear(protein_atom_feature_dim, emb_dim)

        # center pos 坐标中心
        self.center_pos_mode = config.center_pos_mode  # ['none', 'protein']

        # time embedding 时间嵌入层,与 小分子节点特征 embeding 层
        self.time_emb_dim = config.time_emb_dim
        self.time_emb_mode = config.time_emb_mode  # ['simple', 'sin']
        if self.time_emb_dim > 0:
            if self.time_emb_mode == 'simple':
                self.ligand_atom_emb = nn.Linear(ligand_atom_feature_dim + 1, emb_dim)
            elif self.time_emb_mode == 'sin':
                self.time_emb = nn.Sequential(
                    SinusoidalPosEmb(self.time_emb_dim),
                    nn.Linear(self.time_emb_dim, self.time_emb_dim * 4),
                    nn.GELU(),
                    nn.Linear(self.time_emb_dim * 4, self.time_emb_dim)
                )
                self.ligand_atom_emb = nn.Linear(ligand_atom_feature_dim + self.time_emb_dim, emb_dim)
            else:
                raise NotImplementedError
        else:
            self.ligand_atom_emb = nn.Linear(ligand_atom_feature_dim, emb_dim)

        # 实例化神经网络
        self.refine_net_type = config.model_type
        self.refine_net = get_refine_net(self.refine_net_type, config)
        # 推理 binding affinity 层
        self.dock_inference = nn.Sequential(
            nn.Linear(self.hidden_dim, self.hidden_dim),
            ShiftedSoftplus(),
            nn.Linear(self.hidden_dim, 1),
        )

        # TODO: show warning if drop_protein_in_guide not in config
        self.drop_protein_in_guide = config.get("drop_protein_in_guide", False)
        assert isinstance(self.drop_protein_in_guide, bool), f"`drop_protein_in_guide` should be True or False. Got: {self.drop_protein_in_guide}"

        # TODO: show warning if maximize_property not in config
        self.maximize_property = config.get("maximize_property", False)
        assert isinstance(self.maximize_property, bool), f"`maximize_property` should be True or False. Got: {self.maximize_property}"

        # 推理 binding affinity 层的输出类型
        self.problem_type = config.get("problem_type", "regression")
        assert self.problem_type in ["regression", "classification"], f"`problem_type` should be 'regression' or 'classification'. Got: {self.problem_type}"

5.2 get_loss

训练梯度引导模型中,DockGuideNet3D 的 get_loss 方法计算梯度引导模型的损失。与主模型的 get_diffusion_loss 很像。

(1)先是基本设置,

        # 批次中图数量
        num_graphs = batch_protein.max().item() + 1
        # 去坐标中心
        protein_pos, ligand_pos, _ = center_pos(
            protein_pos, ligand_pos, batch_protein, batch_ligand, mode=self.center_pos_mode)

        # 1. sample noise levels
        if time_step is None:
            # 采样时间 t 及其概率
            time_step, pt = self.sample_time(num_graphs, protein_pos.device, self.sample_time_method)
        else:
            # 时间 t 固定参数输入,只计算概率
            pt = torch.ones_like(time_step).float() / self.num_timesteps
        ## precomputed beforehand to save computational time. Here it is only indexed based on the time_step
        a = self.alphas_cumprod.index_select(0, time_step)  # (num_graphs, ) # α 累乘

(2)扰动小分子的坐标和节点类型;

        # 2. perturb pos and v
        a_pos = a[batch_ligand].unsqueeze(-1)  # (num_ligand_atoms, 1) # α累乘 扩展
        ## sampling a normal distribution to add as noise
        # 初始化/采样噪音
        pos_noise = torch.zeros_like(ligand_pos) 
        pos_noise.normal_()
        ## update the coordinates
        # Xt = a.sqrt() * X0 + (1-a).sqrt() * eps 扰动小分子的坐标
        ligand_pos_perturbed = a_pos.sqrt() * ligand_pos + (1.0 - a_pos).sqrt() * pos_noise  # pos_noise * std
        
        ## update the categories 扰动原子类型
        # Vt = a * V0 + (1-a) / K
        log_ligand_v0 = index_to_log_onehot(ligand_v, self.num_classes) # 转为 onr-hot 类型,然后取对数,即,生成节点类型的对数概率
        ## move some of the probablity mass to other indexes
        # 采样扰动后的节点类型索引,及其对数概率
        ligand_v_perturbed, log_ligand_vt = self.q_v_sample(log_ligand_v0, time_step, batch_ligand)

其中,index_to_log_onehot 和 self.q_v_sample 等一些支持函数与主模型的完全相同,这里就不再重复,把代码直接贴出来。节点类型和坐标在被扰动的时候,都遵循上述提及的公式:

x_{t} = \sqrt{\alpha_{cumprod, t}} x_{0} + \sqrt{1-\alpha_{cumprod, t}}\epsilon

index_to_log_onehot 和 self.q_v_sample 等支持函数的代码如下:

def index_to_log_onehot(x, num_classes):
    '''
     索引转化为对数概率
    '''
    assert x.max().item() < num_classes, f'Error: {x.max().item()} >= {num_classes}'
    x_onehot = F.one_hot(x, num_classes)
    # permute_order = (0, -1) + tuple(range(1, len(x.size())))
    # x_onehot = x_onehot.permute(permute_order)
    log_x = torch.log(x_onehot.float().clamp(min=1e-30))
    return log_x

    def q_v_sample(self, log_v0, t, batch):
        '''
        基于概率,采样节点类型
        返回 节点类型的索引及其对数概率
        '''
        # 扰动,生成 v_t (对数概率)
        log_qvt_v0 = self.q_v_pred(log_v0, t, batch) 
        # 基于概率,随机采样节点类型概率,返回节点类型索引
        sample_index = log_sample_categorical(log_qvt_v0) 
        # 将节点类型转化为 对数概率, 即:log(one-hot)
        log_sample = index_to_log_onehot(sample_index, self.num_classes)
        return sample_index, log_sample

    def q_v_pred(self, log_v0, t, batch):
        '''
        计算 q(V_t | V_0),扰动节点类型
        '''
        # compute q(vt | v0)
        log_cumprod_alpha_t = extract(self.log_alphas_cumprod_v, t, batch)
        log_1_min_cumprod_alpha = extract(self.log_one_minus_alphas_cumprod_v, t, batch)

        log_probs = log_add_exp(
            log_v0 + log_cumprod_alpha_t,
            log_1_min_cumprod_alpha - np.log(self.num_classes)
        )
        return log_probs

def log_sample_categorical(logits):
    uniform = torch.rand_like(logits) # 按照对数概率,采样节点类型的’随机性‘
    gumbel_noise = -torch.log(-torch.log(uniform + 1e-30) + 1e-30) # 随机性进行对数化
    # 随机性与概率合并,返回最大概率的节点类型(不是 one-hot)
    sample_index = (gumbel_noise + logits).argmax(dim=-1) 
    # sample_onehot = F.one_hot(sample, self.num_classes)
    # log_sample = index_to_log_onehot(sample, self.num_classes)
    return sample_index

(3)输入到 forword 函数中计算预测的 Binding Affinity
        # 3. forward-pass NN, feed perturbed pos and v, output noise
        # 将 蛋白坐标和节点类型,扰动后的小分子坐标和节点类型,以及时间步 t 输入到 forward 中
        # 输出预测的 binding affinity
        preds = self(
            protein_pos=protein_pos,
            protein_atom_feature=protein_v,
            batch_protein=batch_protein,

            ligand_pos=ligand_pos_perturbed,
            ligand_atom_feature=F.one_hot(ligand_v_perturbed, self.num_classes), # one-hot
            batch_ligand=batch_ligand,
            time_step=time_step,
            fix_x=True
        )

关于 forward 函数,主要包括:

(3.1) 时间步 t 的嵌入,然后与小分子节点类型 contact 合并成新的小分子节点类型;

        if self.time_emb_dim > 0:
            if self.time_emb_mode == 'simple':
                input_ligand_feat = torch.cat([
                    ligand_atom_feature,
                    (time_step / self.num_timesteps)[batch_ligand].unsqueeze(-1)
                ], -1)
            elif self.time_emb_mode == 'sin':
                time_feat = self.time_emb(time_step)
                time_feat = time_feat[batch_ligand]
                input_ligand_feat = torch.cat([ligand_atom_feature, time_feat], -1)
            else:
                raise NotImplementedError
        else:
            input_ligand_feat = ligand_atom_feature

(3.2) 小分子节点特征的标记与嵌入

        # 新的小分子节点类型 嵌入
        init_ligand_h = self.ligand_atom_emb(input_ligand_feat)
        # 小分子节点的标记
        if self.config.node_indicator:
            init_ligand_h = torch.cat([init_ligand_h, torch.ones(len(init_ligand_h), 1).to(init_ligand_h.device)], -1)

(3.3) 蛋白节点的标记合并蛋白和小分子的坐标以及节点类型

        if self.drop_protein_in_guide is True:
            # 如果 在梯度计算中不包含 蛋白,比如:预测logP
            h_all, pos_all, batch_all = init_ligand_h, ligand_pos, batch_ligand
            # TODO: check `mask_ligand`
            mask_ligand = torch.ones([batch_ligand.size(0)], device=batch_ligand.device).bool()
        else:
            # 在预测中需要包含蛋白

            # 蛋白节点嵌入
            h_protein = self.protein_atom_emb(protein_atom_feature)
            
            # 蛋白节点标记
            if self.config.node_indicator:
                h_protein = torch.cat([h_protein, torch.zeros(len(h_protein), 1).to(h_protein.device)], -1)

            # 合并蛋白和小分子节点类型和坐标
            h_all, pos_all, batch_all, mask_ligand = compose_context(
                h_protein=h_protein,
                h_ligand=init_ligand_h,
                pos_protein=protein_pos,
                pos_ligand=ligand_pos,
                batch_protein=batch_protein,
                batch_ligand=batch_ligand,
            )

关于 compose_context 函数,也与主函数完全一致,如下:

def compose_context(h_protein, h_ligand, pos_protein, pos_ligand, batch_protein, batch_ligand):
    '''
    蛋白质 (protein) 和配体 (ligand) 的嵌入信息、位置和批次信息进行组合,并返回合并后的上下文表示
    '''
    # previous version has problems when ligand atom types are fixed
    # (due to sorting randomly in case of same element)

    batch_ctx = torch.cat([batch_protein, batch_ligand], dim=0) # mask 批次信息 图的序号
    # sort_idx = batch_ctx.argsort() 按照批次信息进行排序的索引
    sort_idx = torch.sort(batch_ctx, stable=True).indices
    
    # mask 用于区分蛋白和小分子的掩码,按照图的序号重拍后
    mask_ligand = torch.cat([
        torch.zeros([batch_protein.size(0)], device=batch_protein.device).bool(),
        torch.ones([batch_ligand.size(0)], device=batch_ligand.device).bool(),
    ], dim=0)[sort_idx]

    batch_ctx = batch_ctx[sort_idx] # 重拍后的批次信息
    # 重拍后的节点特征
    h_ctx = torch.cat([h_protein, h_ligand], dim=0)[sort_idx]  # (N_protein+N_ligand, H) 
    # 重拍后的坐标
    pos_ctx = torch.cat([pos_protein, pos_ligand], dim=0)[sort_idx]  # (N_protein+N_ligand, 3)

    return h_ctx, pos_ctx, batch_ctx, mask_ligand

(3.4) 使用 EGNN 以及推理 binding affinity 层,预测并返回 Binding Affinity 的值

        ## get the hidden states from GNN EGNN 计算预测值(维度是:(atom_number, ))
        outputs = self.refine_net(h_all, pos_all, mask_ligand, batch_all, return_all=return_all, fix_x=fix_x)
        if fix_x:
            assert torch.all(torch.eq(outputs['x'], pos_all))
        ## aggregate/pool the hidden states 整合输出 (graph_num,)
        aggregate_output = scatter(outputs['h'], index=batch_all, dim=0, reduce='sum')
        ## get the binding energy (graph_num,)
        # 推理 binding affinity 层
        output = self.dock_inference(aggregate_output) 
        return output

完整的 forward hanshu 如下:

    def forward(self, protein_pos, protein_atom_feature, ligand_pos, ligand_atom_feature, batch_protein, batch_ligand,
                time_step, return_all=False, fix_x=False):
        '''
        基于 蛋白的坐标和节点类型、扰动后的小分子坐标和节点类型、时间步t,
        以及批次mask(batch_protein, batch_ligand), 
        输出 预测的 binding affinity(也可以是其他)
        '''
        batch_size = batch_protein.max().item() + 1
        # time embedding
        ## added to inform the model about the extent of noise that has been added
        # 时间步 t 的嵌入,然后与小分子节点类型 contact 合并成新的小分子节点类型
        if self.time_emb_dim > 0:
            if self.time_emb_mode == 'simple':
                input_ligand_feat = torch.cat([
                    ligand_atom_feature,
                    (time_step / self.num_timesteps)[batch_ligand].unsqueeze(-1)
                ], -1)
            elif self.time_emb_mode == 'sin':
                time_feat = self.time_emb(time_step)
                time_feat = time_feat[batch_ligand]
                input_ligand_feat = torch.cat([ligand_atom_feature, time_feat], -1)
            else:
                raise NotImplementedError
        else:
            input_ligand_feat = ligand_atom_feature

        # 新的小分子节点类型 嵌入
        init_ligand_h = self.ligand_atom_emb(input_ligand_feat)
        # 小分子节点的标记
        if self.config.node_indicator:
            init_ligand_h = torch.cat([init_ligand_h, torch.ones(len(init_ligand_h), 1).to(init_ligand_h.device)], -1)
        
        if self.drop_protein_in_guide is True:
            # 如果 在梯度计算中不包含 蛋白,比如:预测logP
            h_all, pos_all, batch_all = init_ligand_h, ligand_pos, batch_ligand
            # TODO: check `mask_ligand`
            mask_ligand = torch.ones([batch_ligand.size(0)], device=batch_ligand.device).bool()
        else:
            # 在预测中需要包含蛋白

            # 蛋白节点嵌入
            h_protein = self.protein_atom_emb(protein_atom_feature)
            
            # 蛋白节点标记
            if self.config.node_indicator:
                h_protein = torch.cat([h_protein, torch.zeros(len(h_protein), 1).to(h_protein.device)], -1)

            # 合并蛋白和小分子节点类型和坐标
            h_all, pos_all, batch_all, mask_ligand = compose_context(
                h_protein=h_protein,
                h_ligand=init_ligand_h,
                pos_protein=protein_pos,
                pos_ligand=ligand_pos,
                batch_protein=batch_protein,
                batch_ligand=batch_ligand,
            )
        ## get the hidden states from GNN EGNN 计算预测值(维度是:(atom_number, ))
        outputs = self.refine_net(h_all, pos_all, mask_ligand, batch_all, return_all=return_all, fix_x=fix_x)
        if fix_x:
            assert torch.all(torch.eq(outputs['x'], pos_all))
        ## aggregate/pool the hidden states 整合输出 (graph_num,)
        aggregate_output = scatter(outputs['h'], index=batch_all, dim=0, reduce='sum')
        ## get the binding energy (graph_num,)
        # 推理 binding affinity 层
        output = self.dock_inference(aggregate_output) 
        return output


(4)随后计算预测的 Binding Affinity 与真实值 dock 之间的损失;
 
        if self.problem_type == "regression":
            # 如果梯度引导模型是 回归模式
            loss_func = nn.MSELoss()
            loss = loss_func(preds.view(-1), dock)
        elif self.problem_type == "classification":
            # 分类模式
            loss_func = nn.BCEWithLogitsLoss()
            loss = loss_func(preds.view(-1), dock.float())
        else:
            raise ValueError(f"Unknown problem type: {self.problem_type}")
        

(5)返回损失及其预测值。

        # 返回损失和预测值
        if return_pred:
            return loss, preds
        else:
            return loss

get_loss 的完整代码如下:

    def get_loss(
            self, protein_pos, protein_v, batch_protein, ligand_pos, ligand_v, batch_ligand, dock, time_step=None, return_pred=False
    ):
        # 批次中图数量
        num_graphs = batch_protein.max().item() + 1
        # 去坐标中心
        protein_pos, ligand_pos, _ = center_pos(
            protein_pos, ligand_pos, batch_protein, batch_ligand, mode=self.center_pos_mode)

        # 1. sample noise levels
        if time_step is None:
            # 采样时间 t 及其概率
            time_step, pt = self.sample_time(num_graphs, protein_pos.device, self.sample_time_method)
        else:
            # 时间 t 固定参数输入,只计算概率
            pt = torch.ones_like(time_step).float() / self.num_timesteps
        ## precomputed beforehand to save computational time. Here it is only indexed based on the time_step
        a = self.alphas_cumprod.index_select(0, time_step)  # (num_graphs, ) # α 累乘

        # 2. perturb pos and v
        a_pos = a[batch_ligand].unsqueeze(-1)  # (num_ligand_atoms, 1) # α累乘 扩展
        ## sampling a normal distribution to add as noise
        # 初始化/采样噪音
        pos_noise = torch.zeros_like(ligand_pos) 
        pos_noise.normal_()
        ## update the coordinates
        # Xt = a.sqrt() * X0 + (1-a).sqrt() * eps 扰动小分子的坐标
        ligand_pos_perturbed = a_pos.sqrt() * ligand_pos + (1.0 - a_pos).sqrt() * pos_noise  # pos_noise * std
        
        ## update the categories 扰动原子类型
        # Vt = a * V0 + (1-a) / K
        log_ligand_v0 = index_to_log_onehot(ligand_v, self.num_classes) # 转为 onr-hot 类型,然后取对数,即,生成节点类型的对数概率
        ## move some of the probablity mass to other indexes
        # 采样扰动后的节点类型索引,及其对数概率
        ligand_v_perturbed, log_ligand_vt = self.q_v_sample(log_ligand_v0, time_step, batch_ligand)

        # 3. forward-pass NN, feed perturbed pos and v, output noise
        # 将 蛋白坐标和节点类型,扰动后的小分子坐标和节点类型,以及时间步 t 输入到 forward 中
        # 输出预测的 binding affinity
        preds = self(
            protein_pos=protein_pos,
            protein_atom_feature=protein_v,
            batch_protein=batch_protein,

            ligand_pos=ligand_pos_perturbed,
            ligand_atom_feature=F.one_hot(ligand_v_perturbed, self.num_classes), # one-hot
            batch_ligand=batch_ligand,
            time_step=time_step,
            fix_x=True
        )

        
        if self.problem_type == "regression":
            # 如果梯度引导模型是 回归模式
            loss_func = nn.MSELoss()
            loss = loss_func(preds.view(-1), dock)
        elif self.problem_type == "classification":
            # 分类模式
            loss_func = nn.BCEWithLogitsLoss()
            loss = loss_func(preds.view(-1), dock.float())
        else:
            raise ValueError(f"Unknown problem type: {self.problem_type}")
        
        # 返回损失和预测值
        if return_pred:
            return loss, preds
        else:
            return loss

get_loss 函数与第三部分组成了一个完整的梯度引导模型的训练过程,至此,训练一个梯度引导模型相关的部分就全部完成。

5.3 get_gradients_guide 

get_gradients_guide 是 DockGuideNet3D 非常重要的一个函数,用于在分子生成的每一个时间 t,小分子的 v_t 和 x_t 组成的分子,与目标属性之间的差距。  在分子生成的 2.3 梯度引导的分子分生成中,调用了 DockGuideNet3D 的 get_gradients_guide 方法来计算梯度,在主模型的 sample_multi_guided_diffusion 方法中被引用计算梯度,返回目标值(Bindding Affinity)对小分子坐标的梯度。

(1)设置评估模式,梯度清零

         self.eval() # 模型评估模式 (不启用 dropout 和 batch normalization 等训练时特有的操作)
        self.zero_grad() # 梯度清零 

(2)启用梯度计算,将坐标和节点设置为需要梯度,调用 forward 函数预测对应的 Binding Affinity 结果,对预测值进行剪裁,并设置预测值是越大越好(self.maximize_property=True)还是越小越好(self.maximize_property=False),计算坐标和节点类型的梯度,返回坐标梯度。(文章中未涉及节点类型的梯度)

      # 启动梯度计算
        with torch.enable_grad(): ## needed during inference
            ligand_pos = ligand_pos.detach().requires_grad_(True) # 设置坐标梯度
            if not pos_only:# 设置节点类型梯度
                ligand_atom_feature = ligand_atom_feature.detach().requires_grad_(True)
            # 调用 forward 函数计算预测的 binding affinity
            pred = self(
                protein_pos=protein_pos,
                protein_atom_feature=protein_atom_feature.float(),
                ligand_pos=ligand_pos,
                ligand_atom_feature=ligand_atom_feature.float(),
                batch_protein=batch_protein,
                batch_ligand=batch_ligand,
                time_step=time_step,
                fix_x=True
            )
            # 预测结果剪裁
            if self.problem_type not in ("classification", ):
                if clamp_pred_min is not None or clamp_pred_max is not None:
                    pred = torch.clamp(pred, min=clamp_pred_min, max=clamp_pred_max)
            else:
                raise NotImplementedError(f"Not implemented for {self.problem_type} problem type")
            
            # 如果优化目标是值越大越好
            if self.maximize_property:
                # maximize pred => minimize -pred`
                pred = -pred
            # pred.mean().backward()
            ## must reduce the binding energies (Vina Dock) - take the mean of pred across the batch 
            ## "ligand_pos_grad" is in increasing direction
            ## must subtract it at the end
            # 计算梯度
            ligand_pos_grad = torch.autograd.grad(pred.sum(), ligand_pos, retain_graph=True)[0]  # 坐标梯度
            if not pos_only: # 节点类型梯度
                ligand_atom_feature_grad = torch.autograd.grad(pred.sum(), ligand_atom_feature, retain_graph=True)[0]
                return ligand_pos_grad, ligand_atom_feature_grad
            # 返回坐标梯度
            return ligand_pos_grad

get_gradients_guide 的 完整代码如下:

  def get_gradients_guide(self, protein_pos, protein_atom_feature, ligand_pos, ligand_atom_feature, 
        batch_protein, batch_ligand, time_step, 
        pos_only=False, 
        clamp_pred_min=None, clamp_pred_max=None # 预测值的最大值和最小值剪裁
        ):
        ## get the gradients w.r.t ligand position and features using the Binding-affinity EGNN predictor
        self.eval() # 模型评估模式 (不启用 dropout 和 batch normalization 等训练时特有的操作)
        self.zero_grad() # 梯度清零 
        # 启动梯度计算
        with torch.enable_grad(): ## needed during inference
            ligand_pos = ligand_pos.detach().requires_grad_(True) # 设置坐标梯度
            if not pos_only:# 设置节点类型梯度
                ligand_atom_feature = ligand_atom_feature.detach().requires_grad_(True)
            # 调用 forward 函数计算预测的 binding affinity
            pred = self(
                protein_pos=protein_pos,
                protein_atom_feature=protein_atom_feature.float(),
                ligand_pos=ligand_pos,
                ligand_atom_feature=ligand_atom_feature.float(),
                batch_protein=batch_protein,
                batch_ligand=batch_ligand,
                time_step=time_step,
                fix_x=True
            )
            # 预测结果剪裁
            if self.problem_type not in ("classification", ):
                if clamp_pred_min is not None or clamp_pred_max is not None:
                    pred = torch.clamp(pred, min=clamp_pred_min, max=clamp_pred_max)
            else:
                raise NotImplementedError(f"Not implemented for {self.problem_type} problem type")
            
            # 如果优化目标是值越大越好
            if self.maximize_property:
                # maximize pred => minimize -pred`
                pred = -pred
            # pred.mean().backward()
            ## must reduce the binding energies (Vina Dock) - take the mean of pred across the batch 
            ## "ligand_pos_grad" is in increasing direction
            ## must subtract it at the end
            # 计算梯度
            ligand_pos_grad = torch.autograd.grad(pred.sum(), ligand_pos, retain_graph=True)[0]  # 坐标梯度
            if not pos_only: # 节点类型梯度
                ligand_atom_feature_grad = torch.autograd.grad(pred.sum(), ligand_atom_feature, retain_graph=True)[0]
                return ligand_pos_grad, ligand_atom_feature_grad
            # 返回坐标梯度
            return ligand_pos_grad

六、 ScorePosNet3D 中的等变神经网络

主模型 ScorePosNet3D 中可以选择两种等变网络: EGNN 和 uni_o2,作者在文章中和代码中使用的都是 EGNN。 梯度引导模型 DockGuideNet3D 使用的也是 EGNN。

关于 EGNN 的介绍,可以参考之前的博客:药物设计中的SE3等变图神经网络层- EGNN 代码解析_se3等变神经网络-CSDN博客文章浏览阅读1.8k次,点赞13次,收藏37次。此部分内容介绍了常用在药物设计深度学习中的SE3等变网络层 EGNN。主要对EGNN的代码逻辑、模块进行解析,并介绍其中的SE3等变在模型中的原理。_se3等变神经网络 https://blog.csdn.net/wufeil7/article/details/139456373?spm=1001.2014.3001.5501

关于 uni_o2 作者提供了代码,但是没有相应的介绍,不知道是否符合 SE3等变。因此这里暂时不介绍。

七、总结

TagMol 是一个梯度引导的分子生成扩散模型,作者提供了完整的项目代码。

通过对这个项目的代码详细解析,可以仔细理解分子生成的扩散模型的原理及其代码架构。

以 TagMol 作为基础模型,可以很快的理解其他基于扩散模型的分子生成模型,可以快速窥探每个模型的特点。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

DrugAutoPilot

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值