multitask-graph-attention(一)

多任务图注意力框架预测药物毒性,原文:Mining Toxicity Information from Large Amounts of Toxicity Data,代码:MGA,文章从 MGA/interpretation/Ames_interpret.ipynb 开始
请添加图片描述


1.build_dataset

# Name of the dataset; both data file paths below are derived from it.
args['data_name'] = 'toxicity'  # change

data_prefix = '../data/' + args['data_name']
args['bin_path'] = data_prefix + '.bin'
args['group_path'] = data_prefix + '_group.csv'

# Tasks used in this run.
args['select_task_list'] = [
    'Carcinogenicity', 'Ames Mutagenicity', 'Respiratory toxicity',
    'Eye irritation', 'Eye corrosion', 'Cardiotoxicity1', 'Cardiotoxicity5',
    'Cardiotoxicity10', 'Cardiotoxicity30',
    'CYP1A2', 'CYP2C19', 'CYP2C9', 'CYP2D6', 'CYP3A4',
    'NR-AR', 'NR-AR-LBD', 'NR-AhR', 'NR-Aromatase', 'NR-ER', 'NR-ER-LBD',
    'NR-PPAR-gamma', 'SR-ARE', 'SR-ATAD5', 'SR-HSE', 'SR-MMP', 'SR-p53',
    'Acute oral toxicity (LD50)', 'LC50DM', 'BCF', 'LC50', 'IGC50',
]  # change

# Every known task; a task's position in this list is the index that gets
# passed on as `select_task_index`.
args['all_task_list'] = [
    'Carcinogenicity', 'Ames Mutagenicity', 'Respiratory toxicity',
    'Eye irritation', 'Eye corrosion', 'Cardiotoxicity1', 'Cardiotoxicity5',
    'Cardiotoxicity10', 'Cardiotoxicity30',
    'CYP1A2', 'CYP2C19', 'CYP2C9', 'CYP2D6', 'CYP3A4',
    'NR-AR', 'NR-AR-LBD', 'NR-AhR', 'NR-Aromatase', 'NR-ER', 'NR-ER-LBD',
    'NR-PPAR-gamma', 'SR-ARE', 'SR-ATAD5', 'SR-HSE', 'SR-MMP', 'SR-p53',
    'Acute oral toxicity (LD50)', 'LC50DM', 'BCF', 'LC50', 'IGC50',
]  # change

# Positions (within all_task_list) of the tasks that were selected.
args['select_task_index'] = [
    position
    for position, task_name in enumerate(args['all_task_list'])
    if task_name in args['select_task_list']
]

train_set, val_set, test_set, task_number = build_dataset.load_graph_from_csv_bin_for_splited(
    bin_path=args['bin_path'],
    group_path=args['group_path'],
    select_task_index=args['select_task_index'],
)
  • 根据参数构建数据集。这里 args['all_task_list'] 和 args['select_task_list'] 完全一致,因此 args['select_task_index'] 就是 0 到 30 的全部下标
args['bin_path'],args['group_path'],args['select_task_index']
"""
('../data/toxicity.bin',
 '../data/toxicity_group.csv',
 [0,
  1,
  2,
  ...
  30])
  """

1.1.load_graph_from_csv_bin_for_splited

def load_graph_from_csv_bin_for_splited(
        bin_path='example.bin',
        group_path='example.csv',
        select_task_index=None):
    """Load saved molecule graphs and split them into train/val/test sets.

    Args:
        bin_path: Graph file read with ``load_graphs``; its label dict must
            contain the entries ``'labels'`` and ``'mask'``.
        group_path: CSV with a ``smiles`` column and a ``group`` column whose
            values ``'training'``, ``'valid'`` and ``'test'`` assign each row
            to a split. Rows with any other group value are ignored.
        select_task_index: Optional list of column indices; when given, only
            these columns of ``labels`` and ``mask`` are kept.

    Returns:
        ``(train_set, val_set, test_set, task_number)``. Each set is a list of
        ``[smiles, graph, label_row, mask_row]`` entries and ``task_number``
        is the number of label columns after selection.
    """
    # Read the group file once; it holds both the SMILES and the split tag.
    group_table = pd.read_csv(group_path, index_col=None)
    smiles = group_table.smiles.values
    group = group_table.group.to_list()

    graphs, detailed_information = load_graphs(bin_path)

    labels = detailed_information['labels']
    mask = detailed_information['mask']

    if select_task_index is not None:
        labels = labels[:, select_task_index]
        mask = mask[:, select_task_index]

    # A molecule whose mask row averages to 0 has no usable label among the
    # selected tasks, so it is excluded from every split. A set keeps the
    # membership test below O(1) per molecule.
    mask_row_mean = torch.mean(mask.float(), 1).numpy().tolist()
    not_use_index = {
        index for index, row_mean in enumerate(mask_row_mean) if row_mean == 0
    }

    split_indices = {'training': [], 'valid': [], 'test': []}
    for index, group_name in enumerate(group):
        if index in not_use_index:
            continue
        if group_name in split_indices:
            split_indices[group_name].append(index)
    train_index = split_indices['training']
    val_index = split_indices['valid']
    test_index = split_indices['test']

    graphs_np = np.array(graphs)
    train_smiles, val_smiles, test_smiles = split_dataset_according_index(
        smiles, train_index, val_index, test_index)
    train_labels, val_labels, test_labels = split_dataset_according_index(
        labels.numpy(), train_index, val_index, test_index, data_type='pd')
    train_mask, val_mask, test_mask = split_dataset_according_index(
        mask.numpy(), train_index, val_index, test_index, data_type='pd')
    train_graph, val_graph, test_graph = split_dataset_according_index(
        graphs_np, train_index, val_index, test_index)

    task_number = train_labels.values.shape[1]

    def assemble(smiles_part, graph_part, labels_part, mask_part):
        # One [smiles, graph, label_row, mask_row] entry per molecule.
        return [
            list(sample)
            for sample in zip(smiles_part, graph_part,
                              labels_part.values, mask_part.values)
        ]

    train_set = assemble(train_smiles, train_graph, train_labels, train_mask)
    val_set = assemble(val_smiles, val_graph, val_labels, val_mask)
    test_set = assemble(test_smiles, test_graph, test_labels, test_mask)
    print(len(train_set), len(val_set), len(test_set),  task_number)
    return train_set, val_set, test_set, task_number

将源代码中的数据来源改成 Hepatotoxicity, select_task_index 设为 None 可以运行,打印输出:

train_set[:3],val_set[:3],test_set[:3],task_number
"""
([['O=[N+]([O-])c1cccc([N+](=O)[O-])c1', Graph(num_nodes=12, num_edges=24,
         ndata_schemes={'atom': Scheme(shape=(40,), dtype=torch.int64)}
         edata_schemes={'normal': Scheme(shape=(), dtype=torch.float32), 'etype': Scheme(shape=(), dtype=torch.int64)}), array([1]), array([1])],
  ['C#C[C@]1(O)CC[C@H]2[C@@H]3CCc4cc(OS(=O)(=O)O)ccc4[C@H]3CC[C@@]21C',
   Graph(num_nodes=26, num_edges=58,
         ndata_schemes={'atom': Scheme(shape=(40,), dtype=torch.int64)}
         edata_schemes={'normal': Scheme(shape=(), dtype=torch.float32), 'etype': Scheme(shape=(), dtype=torch.int64)}),
   array([1]),
   array([1])],
  ['C=C1C(=CC=C2CCCC3(C)C2CCC3C(C)CCCC(C)C)CC(O)CC1O',
   Graph(num_nodes=29, num_edges=62,
         ndata_schemes={'atom': Scheme(shape=(40,), dtype=torch.int64)}
         edata_schemes={'normal': Scheme(shape=(), dtype=torch.float32), 'etype': Scheme(shape=(), dtype=torch.int64)}),
   array([1]),
   array([1])]],
 [['ClC1C(Cl)C(Cl)C(Cl)C(Cl)C1Cl', Graph(num_nodes=12, num_edges=24,
         ndata_schemes={'atom': Scheme(shape=(40,), dtype=torch.int64)}
         edata_schemes={'normal': Scheme(shape=(), dtype=torch.float32), 'etype': Scheme(shape=(), dtype=torch.int64)}), array([1]), array([1])],
  ['Cc1c([N+](=O)[O-])cc([N+](=O)[O-])cc1[N+](=O)[O-]',
   Graph(num_nodes=16, num_edges=32,
         ndata_schemes={'atom': Scheme(shape=(40,), dtype=torch.int64)}
         edata_schemes={'normal': Scheme(shape=(), dtype=torch.float32), 'etype': Scheme(shape=(), dtype=torch.int64)}),
   array([1]),
   array([1])],
  ['Nc1nc(NC2CC2)c2ncn([C@H]3C=C[C@@H](CO)C3)c2n1',
   Graph(num_nodes=21, num_edges=48,
         ndata_schemes={'atom': Scheme(shape=(40,), dtype=torch.int64)}
         edata_schemes={'normal': Scheme(shape=(), dtype=torch.float32), 'etype': Scheme(shape=(), dtype=torch.int64)}),
   array([1]),
   array([1])]],
 [['C[C@H](N)C(=O)N[C@@H](C)C(=O)NC1[C@@H]2CN(c3nc4c(cc3F)c(=O)c(C(=O)O)cn4-c3ccc(F)cc3F)C[C@H]12',
   Graph(num_nodes=40, num_edges=88,
         ndata_schemes={'atom': Scheme(shape=(40,), dtype=torch.int64)}
         edata_schemes={'normal': Scheme(shape=(), dtype=torch.float32), 'etype': Scheme(shape=(), dtype=torch.int64)}),
   array([1]),
   array([1])],
  ['C=CCOc1ccc(CC(=O)O)cc1Cl', Graph(num_nodes=15, num_edges=30,
         ndata_schemes={'atom': Scheme(shape=(40,), dtype=torch.int64)}
         edata_schemes={'normal': Scheme(shape=(), dtype=torch.float32), 'etype': Scheme(shape=(), dtype=torch.int64)}), array([1]), array([1])],
  ['O=C(O)CCCCCCNC1c2ccccc2CCc2ccccc21', Graph(num_nodes=25, num_edges=54,
         ndata_schemes={'atom': Scheme(shape=(40,), dtype=torch.int64)}
         edata_schemes={'normal': Scheme(shape=(), dtype=torch.float32), 'etype': Scheme(shape=(), dtype=torch.int64)}), array([1]), array([1])]],
 1)
"""
  • 这里每个数据点有 smiles,graph,label 和 mask,根据 2.3.build_mask 的分析,这里 mask 值为 0 表示 label 为 na,即数据无效

1.2.split_dataset_according_index

def split_dataset_according_index(dataset, train_index, val_index, test_index, data_type='np'):
    """Slice ``dataset`` into train/val/test parts using three index lists.

    With ``data_type='np'`` the indexed parts are returned as they are; with
    ``data_type='pd'`` each part is wrapped in a ``pd.DataFrame``. Any other
    ``data_type`` yields ``None``.
    """
    if data_type not in ('pd', 'np'):
        return None
    parts = [dataset[train_index], dataset[val_index], dataset[test_index]]
    if data_type == 'pd':
        parts = [pd.DataFrame(part) for part in parts]
    return tuple(parts)
  • 构建模型所用的数据构造方法应该是下面的函数

2.built_data_and_save_for_splited

def built_data_and_save_for_splited(
        origin_path='example.csv',
        save_path='example.bin',
        group_path='example_group.csv',
        task_list_selected=None,
         ):
    """Build molecule graphs from a CSV and save the graphs and split info.

    Args:
        origin_path: Source CSV. It must contain a ``smiles`` column and a
            ``group`` column; every other column is treated as a task label
            unless ``task_list_selected`` is given.
        save_path: Output path handed to ``save_graphs`` (graphs plus the
            ``'labels'`` and ``'mask'`` tensors).
        group_path: Output CSV with the columns ``smiles`` and ``group``.
        task_list_selected: Optional list of label column names to use instead
            of all non-``smiles``/``group`` columns.

    Raises:
        ValueError: If no molecule could be converted, so there is nothing to
            save.
    """
    # Missing labels are replaced by the sentinel 123456, the same value
    # multi_task_build_dataset passes to build_mask as mask_value.
    data_origin = pd.read_csv(origin_path).fillna(123456)
    if task_list_selected is not None:
        labels_list = task_list_selected
    else:
        labels_list = [
            column for column in data_origin.columns
            if column not in ['smiles', 'group']
        ]
    data_set_gnn = multi_task_build_dataset(
        dataset_smiles=data_origin, labels_list=labels_list, smiles_name='smiles')

    # zip(*[]) cannot be unpacked into five names; fail with a clear message.
    if not data_set_gnn:
        raise ValueError(
            'No molecule from {} could be converted to a graph; '
            'nothing to save.'.format(origin_path))

    # Transpose the per-molecule records into one list per field.
    smiles, graphs, labels, mask, split_index = map(list, zip(*data_set_gnn))
    graph_labels = {'labels': torch.tensor(labels),
                    'mask': torch.tensor(mask)
                    }
    split_index_pd = pd.DataFrame({'smiles': smiles, 'group': split_index})
    split_index_pd.to_csv(group_path, index=None, columns=None)
    save_graphs(save_path, graphs, graph_labels)
    # Report success only after the graphs have actually been written.
    print('Molecules graph is saved!')

  • 根据原始的 csv 文件构造分子图,输出 _group.csv 和 .bin 文件
  • 原始的 csv 文件应该是有多列,其中两列分别是 smiles 和 group,其他列是 task 的名字
  • map(list, zip(*data_set_gnn)) 的效果可以类比下面的情况,相当于把每个分子的 smiles,graphs 等分别汇总存储
# Toy example: each inner list stands for one molecule's record.
dataset=[
    [1,2,3,4,5], # record of molecule 1
    [6,7,8,9,10]  # record of molecule 2
]
# zip(*dataset) pairs up the i-th field of every record; map(list, ...) turns
# each resulting tuple into a list, i.e. the nested list is transposed.
list(map(list,zip(*dataset)))
"""
[[1, 6], [2, 7], [3, 8], [4, 9], [5, 10]]
"""

2.1.multi_task_build_dataset

def multi_task_build_dataset(dataset_smiles, labels_list, smiles_name):
    """Convert every row of a molecule table into a graph record.

    Args:
        dataset_smiles: DataFrame holding the SMILES column, a ``group``
            column and the label columns.
        labels_list: Names of the label columns to extract.
        smiles_name: Name of the SMILES column.

    Returns:
        A list with one ``[smiles, graph, labels, mask, group]`` entry per
        molecule that could be converted. Molecules whose conversion raises
        are reported on stdout and skipped.
    """
    dataset_gnn = []
    failed_molecule = []
    labels = dataset_smiles[labels_list]
    split_index = dataset_smiles['group']
    smilesList = dataset_smiles[smiles_name]
    molecule_number = len(smilesList)
    for i, smiles in enumerate(smilesList):
        try:
            g = construct_RGCN_bigraph_from_smiles(smiles)
            # `i` is a position from enumerate, so index by position; this
            # stays correct even if the frame does not have a 0..n-1 index.
            molecule_labels = labels.iloc[i]
            mask = build_mask(molecule_labels, mask_value=123456)
            molecule = [smiles, g, molecule_labels, mask, split_index.iloc[i]]
        except Exception:
            # Best effort: skip molecules that fail, but let
            # KeyboardInterrupt / SystemExit stop a long-running build.
            print('{} is transformed failed!'.format(smiles))
            molecule_number = molecule_number - 1
            failed_molecule.append(smiles)
        else:
            dataset_gnn.append(molecule)
            print('{}/{} molecule is transformed!'.format(i+1, molecule_number))
    print('{}({}) is transformed failed!'.format(failed_molecule, len(failed_molecule)))
    return dataset_gnn
  • 编码 smiles 为分子图并构建 mask 后返回数据集,shape 应该是 (molecule_num,5)
  • 如果有多个 task,labels 应该是一个分子对应的多个任务的标签,依据 origin_path 文件进行分类

2.2.construct_RGCN_bigraph_from_smiles

def construct_RGCN_bigraph_from_smiles(smiles):
    """Build a bidirected DGL graph with atom and bond-type features.

    Nodes are atoms, with ``atom_features`` stored in ``g.ndata["atom"]``.
    Every bond contributes two directed edges (u->v and v->u) that share the
    same ``etype_features`` value, stored in ``g.edata["etype"]``.
    ``g.edata["normal"]`` holds, per edge, the fraction of the molecule's
    edges that have the same etype, rounded to one decimal.

    Args:
        smiles: SMILES string of the molecule.

    Returns:
        The populated ``DGLGraph``.

    Raises:
        ValueError: If the SMILES string cannot be parsed.
    """
    from collections import Counter

    mol = MolFromSmiles(smiles)
    if mol is None:
        # MolFromSmiles signals a parse failure by returning None.
        raise ValueError('Invalid SMILES: {}'.format(smiles))

    g = DGLGraph()

    # Add nodes
    g.add_nodes(mol.GetNumAtoms())
    g.ndata["atom"] = torch.tensor(
        [atom_features(atom).tolist() for atom in mol.GetAtoms()])

    # Add edges: one entry per direction of every bond.
    src_list = []
    dst_list = []
    etype_feature_all = []
    for bond_idx in range(mol.GetNumBonds()):
        bond = mol.GetBondWithIdx(bond_idx)
        etype_feature = etype_features(bond)
        u = bond.GetBeginAtomIdx()
        v = bond.GetEndAtomIdx()
        src_list.extend([u, v])
        dst_list.extend([v, u])
        etype_feature_all.extend([etype_feature, etype_feature])

    g.add_edges(src_list, dst_list)

    # Count every etype once instead of calling list.count() per edge.
    etype_count = Counter(etype_feature_all)
    edge_total = len(etype_feature_all)
    normal_all = [
        round(etype_count[etype] / edge_total, 1) for etype in etype_feature_all
    ]

    # Explicit dtypes keep the schema stable for molecules without bonds,
    # where an empty list would otherwise produce a float tensor for "etype".
    g.edata["etype"] = torch.tensor(etype_feature_all, dtype=torch.int64)
    g.edata["normal"] = torch.tensor(normal_all, dtype=torch.float32)
    return g
  • 利用 rdkit 中的 atom 对象编码原子全局特征。g.ndata["atom"] 主要是原子属性的独热编码(其中形式电荷和自由基电子数是直接使用的数值),因此一个分子 smiles 转化为分子图后 atom 特征的 shape 是 (atom_num, n),这里的 n 是 40
  • 利用 rdkit 中的 bond 对象编码化学键全局特征。g.edata["etype"] 是化学键属性的数值编码;g.edata["normal"] 是 g.edata["etype"] 的统计数据,即每条边的 etype 在该分子所有边中出现的比例(保留一位小数)
  • construct_RGCN_bigraph_from_smiles 中 RGCN 体现在哪里还不明朗

2.2.1.atom_features

def atom_features(atom, explicit_H = False, use_chirality=True):
    """Encode one RDKit atom as a flat feature vector.

    Feature layout, in order:
        * element symbol, one-hot over 16 entries (unknown -> 'other')
        * degree, one-hot over 0..6
        * formal charge and number of radical electrons (raw values)
        * hybridization, one-hot over 6 entries (unknown -> 'other')
        * aromaticity flag
        * total number of hydrogens, one-hot over 0..4 (unknown -> 4);
          only when ``explicit_H`` is False
        * CIP code one-hot over ['R', 'S'] plus the '_ChiralityPossible'
          flag; only when ``use_chirality`` is True

    With the default arguments this yields 40 values.

    Args:
        atom: RDKit atom object.
        explicit_H: If True, skip the hydrogen-count features.
        use_chirality: If True, append the chirality features.

    Returns:
        ``np.ndarray`` holding the concatenated features.

    Raises:
        Exception: From ``one_of_k_encoding`` if the degree is not in 0..6.
    """
    symbol_part = one_of_k_encoding_unk(
        atom.GetSymbol(),
        [
            'B',
            'C',
            'N',
            'O',
            'F',
            'Si',
            'P',
            'S',
            'Cl',
            'As',
            'Se',
            'Br',
            'Te',
            'I',
            'At',
            'other'
        ])
    degree_part = one_of_k_encoding(atom.GetDegree(), [0, 1, 2, 3, 4, 5, 6])
    charge_part = [atom.GetFormalCharge(), atom.GetNumRadicalElectrons()]
    hybridization_part = one_of_k_encoding_unk(
        atom.GetHybridization(), [
            Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2,
            Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D,
            Chem.rdchem.HybridizationType.SP3D2, 'other'])

    # The parentheses make this multi-line sum one valid expression.
    results = (symbol_part
               + degree_part
               + charge_part
               + hybridization_part
               + [atom.GetIsAromatic()])

    # In case of explicit hydrogen(QM8, QM9), avoid calling `GetTotalNumHs`
    if not explicit_H:
        results = results + one_of_k_encoding_unk(atom.GetTotalNumHs(),
                                                  [0, 1, 2, 3, 4])
    if use_chirality:
        # Atoms without an assigned CIP code get an all-False R/S encoding.
        if atom.HasProp('_CIPCode'):
            cip_part = one_of_k_encoding_unk(atom.GetProp('_CIPCode'), ['R', 'S'])
        else:
            cip_part = [False, False]
        results = results + cip_part + [atom.HasProp('_ChiralityPossible')]

    return np.array(results)

atom.GetSymbol() 获取原子的元素符号,atom.GetDegree() 获取原子的度(即与该原子直接成键的原子数)。函数对原子的这些全局属性进行 one-hot 编码,并拼接成一个特征向量

2.2.2.one_of_k_encoding_unk

def one_of_k_encoding_unk(x, allowable_set):
    """One-hot encode ``x``; values outside ``allowable_set`` map to its last entry."""
    target = x if x in allowable_set else allowable_set[-1]
    encoding = []
    for candidate in allowable_set:
        encoding.append(target == candidate)
    return encoding
  • 根据 allowable_set 进行 one-hot 编码;如果 x 不在 allowable_set 中,则把最后一个位置(例如 'other')设为 1

2.2.3.one_of_k_encoding

def one_of_k_encoding(x, allowable_set):
    """One-hot encode ``x`` over ``allowable_set``.

    Raises:
        Exception: If ``x`` is not a member of ``allowable_set``.
    """
    if x in allowable_set:
        return [x == option for option in allowable_set]
    message = "input {0} not in allowable set{1}:".format(x, allowable_set)
    raise Exception(message)
  • 根据 allowable_set 进行 one-hot 编码,如果不在 allowable_set 会报错

2.2.4.etype_features

def etype_features(bond, use_chirality=True, atompair=True):
    """Encode one RDKit bond as a single integer edge type.

    The result is ``a + 4*b + 8*c + 16*d + 64*e`` where
        * a: bond type index (0 single, 1 double, 2 triple, 3 aromatic)
        * b: 1 if the bond is conjugated, else 0
        * c: 1 if the bond is in a ring, else 0
        * d: stereo index over [STEREONONE, STEREOANY, STEREOZ, STEREOE],
          unknown values map to the last entry; only if ``use_chirality``
        * e: index of the atom-pair group returned by
          ``one_of_k_atompair_encoding``; only if ``atompair``

    NOTE: ``c`` used to be always 0 because ``bond.IsInRing`` was compared to
    ``True`` without being called. Graph files built with that version encode
    ring bonds differently and need to be regenerated.

    Args:
        bond: RDKit bond object.
        use_chirality: Include the stereo component ``d``.
        atompair: Include the atom-pair component ``e``.

    Returns:
        The integer edge type.

    Raises:
        ValueError: If the bond type is not single/double/triple/aromatic.
    """
    bt = bond.GetBondType()
    supported_bond_types = [
        Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
        Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC,
    ]
    if bt not in supported_bond_types:
        # Previously this fell through and failed later on an unbound `a`.
        raise ValueError('Unsupported bond type: {}'.format(bt))
    a = supported_bond_types.index(bt)

    b = 1 if bond.GetIsConjugated() else 0

    # IsInRing is a method; it has to be called to obtain the flag.
    c = 1 if bond.IsInRing() else 0

    index = a * 1 + b * 4 + c * 8
    if use_chirality:
        bond_feats_4 = one_of_k_encoding_unk(
            str(bond.GetStereo()),
            ["STEREONONE", "STEREOANY", "STEREOZ", "STEREOE"])
        for i, m in enumerate(bond_feats_4):
            if m == True:
                d = i
        index = index + d * 16
    if atompair == True:
        atom_pair_str = bond.GetBeginAtom().GetSymbol() + bond.GetEndAtom().GetSymbol()
        bond_feats_5 = one_of_k_atompair_encoding(
            atom_pair_str, [['CC'], ['CN', 'NC'], ['ON', 'NO'], ['CO', 'OC'], ['CS', 'SC'],
                            ['SO', 'OS'], ['NN'], ['SN', 'NS'], ['CCl', 'ClC'], ['CF', 'FC'],
                            ['CBr', 'BrC'], ['others']]
        )
        # Assumes the helper always flags one group (presumably falling back
        # to ['others']) -- TODO confirm against one_of_k_atompair_encoding.
        for i, m in enumerate(bond_feats_5):
            if m == True:
                e = i
        index = index + e*64
    return index

以十进制整数编码化学键特征:index 相当于把各个特征按二进制位拼接后再转化为十进制的值,其中 a 占第 0–1 位,b 占第 2 位,c 占第 3 位,d 占第 4–5 位,e 占第 6 位及以上,即 $index = a\times 2^0 + b\times 2^2 + c\times 2^3 + d\times 2^4 + e\times 2^6$

2.3.build_mask

def build_mask(labels_list, mask_value=100):
    """Return a 0/1 mask: 0 where a label equals ``mask_value``, 1 elsewhere."""
    return [0 if label == mask_value else 1 for label in labels_list]

  • 如果 label 的值无效,mask 就是 0,有效的话 mask 的值是 1。
  • 之前在 built_data_and_save_for_splited 函数中 进行了 data_origin.fillna(123456),而这里 mask = build_mask(labels.loc[i], mask_value=123456)
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

_森罗万象

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值