多任务图注意力框架预测药物毒性,原文:Mining Toxicity Information from Large Amounts of Toxicity Data,代码:MGA,文章从 MGA/interpretation/Ames_interpret.ipynb 开始
文章目录
1.build_dataset
# Dataset location: the graph file (.bin) and the split file (_group.csv)
# share the same prefix under ../data/.
args['data_name'] = 'toxicity'  # change
args['bin_path'] = '../data/{}.bin'.format(args['data_name'])
args['group_path'] = '../data/{}_group.csv'.format(args['data_name'])

# Tasks to train on (change this list to train on a subset).
args['select_task_list'] = [
    'Carcinogenicity', 'Ames Mutagenicity', 'Respiratory toxicity',
    'Eye irritation', 'Eye corrosion',
    'Cardiotoxicity1', 'Cardiotoxicity5', 'Cardiotoxicity10', 'Cardiotoxicity30',
    'CYP1A2', 'CYP2C19', 'CYP2C9', 'CYP2D6', 'CYP3A4',
    'NR-AR', 'NR-AR-LBD', 'NR-AhR', 'NR-Aromatase', 'NR-ER', 'NR-ER-LBD',
    'NR-PPAR-gamma', 'SR-ARE', 'SR-ATAD5', 'SR-HSE', 'SR-MMP', 'SR-p53',
    'Acute oral toxicity (LD50)', 'LC50DM', 'BCF', 'LC50', 'IGC50',
]  # change

# Every task column stored in the .bin file, in column order.
args['all_task_list'] = [
    'Carcinogenicity', 'Ames Mutagenicity', 'Respiratory toxicity',
    'Eye irritation', 'Eye corrosion',
    'Cardiotoxicity1', 'Cardiotoxicity5', 'Cardiotoxicity10', 'Cardiotoxicity30',
    'CYP1A2', 'CYP2C19', 'CYP2C9', 'CYP2D6', 'CYP3A4',
    'NR-AR', 'NR-AR-LBD', 'NR-AhR', 'NR-Aromatase', 'NR-ER', 'NR-ER-LBD',
    'NR-PPAR-gamma', 'SR-ARE', 'SR-ATAD5', 'SR-HSE', 'SR-MMP', 'SR-p53',
    'Acute oral toxicity (LD50)', 'LC50DM', 'BCF', 'LC50', 'IGC50',
]  # change

# Column positions (in all_task_list order) of the selected tasks.
args['select_task_index'] = [
    position
    for position, task_name in enumerate(args['all_task_list'])
    if task_name in args['select_task_list']
]

train_set, val_set, test_set, task_number = build_dataset.load_graph_from_csv_bin_for_splited(
    bin_path=args['bin_path'],
    group_path=args['group_path'],
    select_task_index=args['select_task_index'],
)
- 根据参数构建数据集。这里 args['all_task_list'] 和 args['select_task_list'] 完全一致,因此 select_task_index 是全部 31 个任务的下标 0–30(见下方输出)。
# Notebook cell: echo the resolved data paths and the selected task indices.
# The triple-quoted string below is the recorded cell output, kept for reference.
args['bin_path'],args['group_path'],args['select_task_index']
"""
('../data/toxicity.bin',
'../data/toxicity_group.csv',
[0,
1,
2,
...
30])
"""
1.1.load_graph_from_csv_bin_for_splited
def load_graph_from_csv_bin_for_splited(
        bin_path='example.bin',
        group_path='example.csv',
        select_task_index=None):
    """Load pre-built molecular graphs and split them into train/valid/test sets.

    Args:
        bin_path: graph file readable by ``load_graphs``; its label dict must
            contain ``'labels'`` and ``'mask'`` tensors (one row per graph).
        group_path: CSV with ``smiles`` and ``group`` columns, one row per
            graph in the same order as ``bin_path``. ``group`` is
            ``'training'``, ``'valid'`` or ``'test'``; other values are ignored.
        select_task_index: optional list of task (column) indices to keep;
            ``None`` keeps every task.

    Returns:
        ``(train_set, val_set, test_set, task_number)``. Each set is a list of
        ``[smiles, graph, label_row, mask_row]`` entries and ``task_number`` is
        the number of task columns kept.

    Raises:
        ValueError: if the CSV and the graph file hold different numbers of
            molecules (their rows would otherwise be silently misaligned).

    Molecules whose mask is 0 for every kept task carry no usable label and
    are left out of all three splits.
    """
    group_df = pd.read_csv(group_path, index_col=None)  # read the CSV once
    smiles = group_df.smiles.values
    group = group_df.group.to_list()
    graphs, detailed_information = load_graphs(bin_path)
    if len(group) != len(graphs):
        raise ValueError('{} lists {} molecules but {} holds {} graphs'.format(
            group_path, len(group), bin_path, len(graphs)))
    labels = detailed_information['labels']
    mask = detailed_information['mask']
    if select_task_index is not None:
        labels = labels[:, select_task_index]
        mask = mask[:, select_task_index]

    # A molecule is unusable when its mask is 0 on every kept task.
    notuse_mask = torch.mean(mask.float(), 1).numpy().tolist()
    not_use_index = {index for index, notuse in enumerate(notuse_mask) if notuse == 0}

    # Row positions per split; the set above keeps each lookup O(1).
    split_positions = {'training': [], 'valid': [], 'test': []}
    for index, group_index in enumerate(group):
        if group_index in split_positions and index not in not_use_index:
            split_positions[group_index].append(index)
    train_index = split_positions['training']
    val_index = split_positions['valid']
    test_index = split_positions['test']

    # Fill an object array element by element so NumPy never tries to treat a
    # graph as a nested sequence; the result is always 1-D.
    graphs_np = np.empty(len(graphs), dtype=object)
    for position, graph in enumerate(graphs):
        graphs_np[position] = graph

    train_smiles, val_smiles, test_smiles = split_dataset_according_index(
        smiles, train_index, val_index, test_index)
    train_labels, val_labels, test_labels = split_dataset_according_index(
        labels.numpy(), train_index, val_index, test_index, data_type='pd')
    train_mask, val_mask, test_mask = split_dataset_according_index(
        mask.numpy(), train_index, val_index, test_index, data_type='pd')
    train_graph, val_graph, test_graph = split_dataset_according_index(
        graphs_np, train_index, val_index, test_index)
    task_number = labels.shape[1]

    def _assemble(part_smiles, part_graphs, part_labels, part_mask):
        # One [smiles, graph, label_row, mask_row] entry per molecule.
        return [list(entry) for entry in zip(
            part_smiles, part_graphs, part_labels.values, part_mask.values)]

    train_set = _assemble(train_smiles, train_graph, train_labels, train_mask)
    val_set = _assemble(val_smiles, val_graph, val_labels, val_mask)
    test_set = _assemble(test_smiles, test_graph, test_labels, test_mask)
    print(len(train_set), len(val_set), len(test_set), task_number)
    return train_set, val_set, test_set, task_number
将源代码中的数据集改为 Hepatotoxicity(单任务),并把 select_task_index 设为 None 后即可运行。打印前 3 个样本及任务数,输出如下:
# Notebook cell: show the first three entries of each split and the task count.
# Each entry is [smiles, graph, label_row, mask_row]; the triple-quoted string
# below is the recorded cell output for the single-task Hepatotoxicity data.
train_set[:3],val_set[:3],test_set[:3],task_number
"""
([['O=[N+]([O-])c1cccc([N+](=O)[O-])c1', Graph(num_nodes=12, num_edges=24,
ndata_schemes={'atom': Scheme(shape=(40,), dtype=torch.int64)}
edata_schemes={'normal': Scheme(shape=(), dtype=torch.float32), 'etype': Scheme(shape=(), dtype=torch.int64)}), array([1]), array([1])],
['C#C[C@]1(O)CC[C@H]2[C@@H]3CCc4cc(OS(=O)(=O)O)ccc4[C@H]3CC[C@@]21C',
Graph(num_nodes=26, num_edges=58,
ndata_schemes={'atom': Scheme(shape=(40,), dtype=torch.int64)}
edata_schemes={'normal': Scheme(shape=(), dtype=torch.float32), 'etype': Scheme(shape=(), dtype=torch.int64)}),
array([1]),
array([1])],
['C=C1C(=CC=C2CCCC3(C)C2CCC3C(C)CCCC(C)C)CC(O)CC1O',
Graph(num_nodes=29, num_edges=62,
ndata_schemes={'atom': Scheme(shape=(40,), dtype=torch.int64)}
edata_schemes={'normal': Scheme(shape=(), dtype=torch.float32), 'etype': Scheme(shape=(), dtype=torch.int64)}),
array([1]),
array([1])]],
[['ClC1C(Cl)C(Cl)C(Cl)C(Cl)C1Cl', Graph(num_nodes=12, num_edges=24,
ndata_schemes={'atom': Scheme(shape=(40,), dtype=torch.int64)}
edata_schemes={'normal': Scheme(shape=(), dtype=torch.float32), 'etype': Scheme(shape=(), dtype=torch.int64)}), array([1]), array([1])],
['Cc1c([N+](=O)[O-])cc([N+](=O)[O-])cc1[N+](=O)[O-]',
Graph(num_nodes=16, num_edges=32,
ndata_schemes={'atom': Scheme(shape=(40,), dtype=torch.int64)}
edata_schemes={'normal': Scheme(shape=(), dtype=torch.float32), 'etype': Scheme(shape=(), dtype=torch.int64)}),
array([1]),
array([1])],
['Nc1nc(NC2CC2)c2ncn([C@H]3C=C[C@@H](CO)C3)c2n1',
Graph(num_nodes=21, num_edges=48,
ndata_schemes={'atom': Scheme(shape=(40,), dtype=torch.int64)}
edata_schemes={'normal': Scheme(shape=(), dtype=torch.float32), 'etype': Scheme(shape=(), dtype=torch.int64)}),
array([1]),
array([1])]],
[['C[C@H](N)C(=O)N[C@@H](C)C(=O)NC1[C@@H]2CN(c3nc4c(cc3F)c(=O)c(C(=O)O)cn4-c3ccc(F)cc3F)C[C@H]12',
Graph(num_nodes=40, num_edges=88,
ndata_schemes={'atom': Scheme(shape=(40,), dtype=torch.int64)}
edata_schemes={'normal': Scheme(shape=(), dtype=torch.float32), 'etype': Scheme(shape=(), dtype=torch.int64)}),
array([1]),
array([1])],
['C=CCOc1ccc(CC(=O)O)cc1Cl', Graph(num_nodes=15, num_edges=30,
ndata_schemes={'atom': Scheme(shape=(40,), dtype=torch.int64)}
edata_schemes={'normal': Scheme(shape=(), dtype=torch.float32), 'etype': Scheme(shape=(), dtype=torch.int64)}), array([1]), array([1])],
['O=C(O)CCCCCCNC1c2ccccc2CCc2ccccc21', Graph(num_nodes=25, num_edges=54,
ndata_schemes={'atom': Scheme(shape=(40,), dtype=torch.int64)}
edata_schemes={'normal': Scheme(shape=(), dtype=torch.float32), 'etype': Scheme(shape=(), dtype=torch.int64)}), array([1]), array([1])]],
1)
"""
- 这里每个数据点由 smiles、graph、label 和 mask 四项组成。根据 2.3.build_mask 的分析,mask 值为 0 表示对应的 label 缺失(原始数据为 NaN),即该任务在这个分子上无效;所有任务的 mask 都为 0 的分子会在划分数据集之前被剔除(见 not_use_index)。
1.2.split_dataset_according_index
def split_dataset_according_index(dataset, train_index, val_index, test_index, data_type='np'):
    """Select the train/valid/test rows of ``dataset`` by integer position lists.

    Args:
        dataset: array supporting NumPy fancy indexing (``dataset[list]``).
        train_index: row positions for the training split.
        val_index: row positions for the validation split.
        test_index: row positions for the test split.
        data_type: ``'np'`` returns the indexed arrays as-is; ``'pd'`` wraps
            each of them in a ``pandas.DataFrame``.

    Returns:
        A ``(train, val, test)`` tuple.

    Raises:
        ValueError: if ``data_type`` is neither ``'np'`` nor ``'pd'``; returning
            ``None`` here would only surface later as a confusing unpack error.
    """
    if data_type not in ('np', 'pd'):
        raise ValueError("data_type must be 'np' or 'pd', got {!r}".format(data_type))
    splits = (dataset[train_index], dataset[val_index], dataset[test_index])
    if data_type == 'pd':
        return tuple(pd.DataFrame(part) for part in splits)
    return splits
- 上面读取的 .bin 和 _group.csv 文件由下面的函数生成,也就是构建模型所用数据的构造方法。
2.built_data_and_save_for_splited
def built_data_and_save_for_splited(
        origin_path='example.csv',
        save_path='example.bin',
        group_path='example_group.csv',
        task_list_selected=None,
        ):
    '''
    Build molecular graphs from a CSV file and save them together with
    their train/valid/test group assignment.

    origin_path: str
        origin csv data set path; must contain a 'smiles' column, a 'group'
        column and one column per task
    save_path: str
        graph output path (written with save_graphs; the label dict holds
        'labels' and 'mask' tensors, one row per saved graph)
    group_path: str
        group output path (csv with the 'smiles' and 'group' of every
        successfully converted molecule, in the same order as the graphs)
    task_list_selected: list
        a list of selected task columns; by default every column except
        'smiles' and 'group' is used

    Raises:
        ValueError: if no molecule could be converted into a graph.
    '''
    data_origin = pd.read_csv(origin_path)
    smiles_name = 'smiles'
    # Missing labels become 123456, the sentinel value that
    # multi_task_build_dataset turns into mask = 0.
    data_origin = data_origin.fillna(123456)
    if task_list_selected is not None:
        labels_list = task_list_selected
    else:
        labels_list = [x for x in data_origin.columns if x not in ['smiles', 'group']]
    data_set_gnn = multi_task_build_dataset(dataset_smiles=data_origin, labels_list=labels_list, smiles_name=smiles_name)
    if not data_set_gnn:
        # zip(*[]) below would otherwise fail with an opaque unpacking error.
        raise ValueError('no molecule in {} could be converted into a graph'.format(origin_path))
    smiles, graphs, labels, mask, split_index = map(list, zip(*data_set_gnn))
    # Each label row is a pandas Series. Convert to plain lists so torch does
    # not index the Series by position (deprecated in recent pandas); the
    # Python scalars are expected to yield the same tensor dtypes as before.
    label_rows = [np.asarray(row).tolist() for row in labels]
    graph_labels = {'labels': torch.tensor(label_rows),
                    'mask': torch.tensor(mask)
                    }
    split_index_pd = pd.DataFrame({'smiles': smiles, 'group': split_index})
    split_index_pd.to_csv(group_path, index=False)
    save_graphs(save_path, graphs, graph_labels)
    print('Molecules graph is saved!')  # only reported once the save succeeded
- 根据原始的 csv 文件构造分子图,输出 _group.csv 和 .bin 文件
- 原始的 csv 文件应该是有多列,其中两列分别是 smiles 和 group,其他列是 task 的名字
- map(list, zip(*data_set_gnn)) 的效果可以类比下面的情况,相当于把每个分子的 smiles,graphs 等分别汇总存储
# Toy example: each inner list plays the role of one molecule's record
# (smiles, graph, labels, mask, group).
dataset=[
[1,2,3,4,5], #mol1
[6,7,8,9,10] #mol2
]
# zip(*dataset) transposes the records into one list per field.
list(map(list,zip(*dataset)))
"""
[[1, 6], [2, 7], [3, 8], [4, 9], [5, 10]]
"""
2.1.multi_task_build_dataset
def multi_task_build_dataset(dataset_smiles, labels_list, smiles_name, mask_value=123456):
    """Convert every SMILES of a DataFrame into a graph record with labels and mask.

    Args:
        dataset_smiles: DataFrame holding the SMILES column, a ``group`` column
            and the task columns named in ``labels_list``. Missing labels are
            expected to be filled with ``mask_value`` beforehand.
        labels_list: task column names collected as labels.
        smiles_name: name of the SMILES column.
        mask_value: sentinel label value marking a missing label (the default
            matches the ``fillna(123456)`` used when building the data).

    Returns:
        List of ``[smiles, graph, label_row, mask, group]`` records, one per
        molecule that was converted successfully. Molecules that fail are
        printed and skipped.
    """
    dataset_gnn = []
    failed_molecule = []
    labels = dataset_smiles[labels_list]
    split_index = dataset_smiles['group']
    smilesList = dataset_smiles[smiles_name]
    molecule_number = len(smilesList)
    for i, smiles in enumerate(smilesList):
        try:
            g = construct_RGCN_bigraph_from_smiles(smiles)
            # .iloc: i is a position from enumerate(). .loc[i] only matches for
            # a default RangeIndex and would fail, or pick the wrong row, otherwise.
            label_row = labels.iloc[i]
            mask = build_mask(label_row, mask_value=mask_value)
            molecule = [smiles, g, label_row, mask, split_index.iloc[i]]
            dataset_gnn.append(molecule)
            print('{}/{} molecule is transformed!'.format(i+1, molecule_number))
        except Exception:
            # Best effort: report and skip unconvertible molecules. Exception
            # (not a bare except) so KeyboardInterrupt still stops the run.
            print('{} is transformed failed!'.format(smiles))
            failed_molecule.append(smiles)
    print('{}({}) is transformed failed!'.format(failed_molecule, len(failed_molecule)))
    return dataset_gnn
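- 把每个 smiles 编码为分子图并构建 mask 后返回数据集:返回值是一个列表,每个元素为 [smiles, graph, labels, mask, group] 五项,列表长度等于成功转换的分子数(转换失败的分子会被打印出来并跳过)。
- 如果有多个 task,labels 是该分子在 labels_list 各任务列上的取值,即 origin_path 对应 csv 中该分子那一行的任务标签,一个分子对应多个任务的标签。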
2.2.construct_RGCN_bigraph_from_smiles
def construct_RGCN_bigraph_from_smiles(smiles):
    """Build a DGL graph with atom features and bond-type edge features from SMILES.

    Nodes are the RDKit atoms; ``g.ndata["atom"]`` holds one ``atom_features``
    row per atom. Every RDKit bond is added as two directed edges
    (begin -> end and end -> begin). ``g.edata["etype"]`` is the integer bond
    code from ``etype_features``. ``g.edata["normal"]`` is the fraction of this
    graph's edges that share the same code, rounded to one decimal.

    Args:
        smiles: SMILES string of the molecule.

    Returns:
        The constructed ``DGLGraph``.

    Raises:
        ValueError: if RDKit cannot parse ``smiles``.
    """
    from collections import Counter

    mol = MolFromSmiles(smiles)
    if mol is None:
        # MolFromSmiles returns None on a parse failure; fail with a clear
        # message instead of an AttributeError further down.
        raise ValueError('RDKit could not parse SMILES {!r}'.format(smiles))

    g = DGLGraph()
    # Add nodes
    g.add_nodes(mol.GetNumAtoms())
    atoms_feature_all = [atom_features(atom).tolist() for atom in mol.GetAtoms()]
    g.ndata["atom"] = torch.tensor(atoms_feature_all)

    # Add edges: both directions of every bond, with one feature entry per
    # directed edge in the same order as the edges.
    src_list = []
    dst_list = []
    etype_feature_all = []
    for i in range(mol.GetNumBonds()):
        bond = mol.GetBondWithIdx(i)
        etype_feature = etype_features(bond)
        u = bond.GetBeginAtomIdx()
        v = bond.GetEndAtomIdx()
        src_list.extend([u, v])
        dst_list.extend([v, u])
        etype_feature_all.extend([etype_feature, etype_feature])
    g.add_edges(src_list, dst_list)

    # Counter keeps this O(E) instead of calling list.count once per edge.
    etype_counts = Counter(etype_feature_all)
    num_edges = len(etype_feature_all)
    normal_all = [round(etype_counts[e] / num_edges, 1) for e in etype_feature_all]

    # Explicit int64 keeps the dtype stable for bond-free molecules, where an
    # empty list would otherwise produce a float tensor.
    g.edata["etype"] = torch.tensor(etype_feature_all, dtype=torch.int64)
    g.edata["normal"] = torch.tensor(normal_all)
    return g
- 利用 rdkit 中的 atom 对象逐个编码原子特征。g.ndata['atom'] 主要由原子属性的独热编码组成,其中形式电荷和自由基电子数是整数值。因此一个分子的 smiles 转化为分子图后,atom 特征的 shape 是 (atom_num, n),默认参数下 n = 40。
- 利用 rdkit 中的 bond 对象编码化学键特征,每个化学键会按两个方向各加入一条边。g.edata['etype'] 是化学键类型的整数编码(见 2.2.4);g.edata['normal'] 是该 etype 在本分子所有边中出现的频率,保留 1 位小数。
- construct_RGCN_bigraph_from_smiles 中的 RGCN(关系图卷积网络)主要体现在边类型上:etype 可以作为关系类型编号,normal 推测被用作按关系类型的归一化系数,具体用法还需要在模型代码中确认。
2.2.1.atom_features
def atom_features(atom, explicit_H = False, use_chirality=True):
    """Encode one RDKit atom as a feature vector.

    Layout (40 entries with the default arguments):
        16  element symbol, one-hot (unknown symbols -> 'other')
         7  degree 0-6, one-hot (any other degree raises)
         2  formal charge, number of radical electrons (raw integers)
         6  hybridization, one-hot (unknown -> 'other')
         1  aromatic flag
         5  total hydrogen count 0-4, one-hot (larger counts -> 4);
            omitted when ``explicit_H`` is True
         3  CIP code R/S one-hot, plus the "_ChiralityPossible" flag;
            omitted when ``use_chirality`` is False

    Args:
        atom: RDKit ``Atom``.
        explicit_H: True when hydrogens are explicit graph atoms (e.g. QM8/QM9),
            so the ``GetTotalNumHs`` feature is skipped.
        use_chirality: append the CIP/chirality features.

    Returns:
        ``np.ndarray`` built from the mixed bool/int feature list.
    """
    symbol_set = ['B', 'C', 'N', 'O', 'F', 'Si', 'P', 'S', 'Cl', 'As', 'Se',
                  'Br', 'Te', 'I', 'At', 'other']
    hybridization_set = [
        Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2,
        Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D,
        Chem.rdchem.HybridizationType.SP3D2, 'other']
    # Parenthesised so the multi-line concatenation is valid Python.
    results = (
        one_of_k_encoding_unk(atom.GetSymbol(), symbol_set)
        + one_of_k_encoding(atom.GetDegree(), [0, 1, 2, 3, 4, 5, 6])
        + [atom.GetFormalCharge(), atom.GetNumRadicalElectrons()]
        + one_of_k_encoding_unk(atom.GetHybridization(), hybridization_set)
        + [atom.GetIsAromatic()]
    )
    # In case of explicit hydrogen (QM8, QM9), avoid calling `GetTotalNumHs`.
    if not explicit_H:
        results = results + one_of_k_encoding_unk(atom.GetTotalNumHs(),
                                                  [0, 1, 2, 3, 4])
    if use_chirality:
        # Atoms without a CIP label have no '_CIPCode' property; check for it
        # explicitly instead of catching every exception with a bare except.
        if atom.HasProp('_CIPCode'):
            cip_code = one_of_k_encoding_unk(atom.GetProp('_CIPCode'), ['R', 'S'])
        else:
            cip_code = [False, False]
        results = results + cip_code + [atom.HasProp('_ChiralityPossible')]
    return np.array(results)
atom.GetSymbol() 获取原子的元素符号,atom.GetDegree() 获取原子的度(与该原子直接成键的邻居原子数),然后对这些原子属性进行 one-hot 编码;不在列表中的元素符号会被归入 'other'。
2.2.2.one_of_k_encoding_unk
def one_of_k_encoding_unk(x, allowable_set):
    """One-hot encode ``x`` over ``allowable_set`` with a catch-all category.

    A value that is not in ``allowable_set`` is encoded as if it were the
    set's last element (e.g. ``'other'``).

    Returns:
        List of booleans, one per element of ``allowable_set``.
    """
    encoded_value = x if x in allowable_set else allowable_set[-1]
    return [encoded_value == candidate for candidate in allowable_set]
- 根据 allowable_set 进行 one-hot 编码;如果 x 不在 allowable_set 中,则把最后一位(通常是 'other')设为 True。
2.2.3.one_of_k_encoding
def one_of_k_encoding(x, allowable_set):
    """Return a strict one-hot list for ``x`` over ``allowable_set``.

    There is no fallback category: a value outside ``allowable_set`` raises
    a plain ``Exception`` naming the value and the allowed set.
    """
    if x in allowable_set:
        return [x == choice for choice in allowable_set]
    raise Exception("input {0} not in allowable set{1}:".format(x, allowable_set))
- 根据 allowable_set 进行 one-hot 编码,如果不在 allowable_set 会报错
2.2.4.etype_features
def _last_hot_index(flags, feature_name):
    """Return the index of the last entry of ``flags`` that equals True.

    Raises:
        ValueError: if no entry is True, i.e. the value is outside the
            encoded categories.
    """
    hot_index = None
    for position, flag in enumerate(flags):
        if flag == True:  # noqa: E712 -- equality (not identity) on purpose
            hot_index = position
    if hot_index is None:
        raise ValueError('bond has no matching {} category'.format(feature_name))
    return hot_index


def etype_features(bond, use_chirality=True, atompair=True):
    """Encode an RDKit bond as a single integer edge-type id.

    ``index = a + 4*b + 8*c + 16*d + 64*e`` where
        a: bond type (0 single, 1 double, 2 triple, 3 aromatic)
        b: 1 if the bond is conjugated, else 0
        c: 1 if the bond is in a ring, else 0
        d: stereo class among STEREONONE/STEREOANY/STEREOZ/STEREOE (any other
           stereo value counts as STEREOE); only added if ``use_chirality``
        e: atom-pair class from ``one_of_k_atompair_encoding``; only added
           if ``atompair``

    Note: ``c`` calls ``bond.IsInRing()``. Code that compared the bound method
    itself to True produced c = 0 for every bond, so graphs (and models) built
    that way encode ring bonds differently; rebuild .bin files rather than
    mixing them.

    Raises:
        ValueError: if the bond type or atom pair is outside the encoded
            categories (previously an UnboundLocalError).
    """
    bt = bond.GetBondType()
    bond_type_flags = [
        bt == Chem.rdchem.BondType.SINGLE, bt == Chem.rdchem.BondType.DOUBLE,
        bt == Chem.rdchem.BondType.TRIPLE, bt == Chem.rdchem.BondType.AROMATIC,
    ]
    a = _last_hot_index(bond_type_flags, 'bond type')
    b = 1 if bond.GetIsConjugated() else 0
    c = 1 if bond.IsInRing() else 0
    index = a * 1 + b * 4 + c * 8
    if use_chirality:
        stereo_flags = one_of_k_encoding_unk(
            str(bond.GetStereo()),
            ["STEREONONE", "STEREOANY", "STEREOZ", "STEREOE"])
        index = index + _last_hot_index(stereo_flags, 'stereo') * 16
    if atompair:
        atom_pair_str = bond.GetBeginAtom().GetSymbol() + bond.GetEndAtom().GetSymbol()
        atom_pair_flags = one_of_k_atompair_encoding(
            atom_pair_str, [['CC'], ['CN', 'NC'], ['ON', 'NO'], ['CO', 'OC'], ['CS', 'SC'],
                            ['SO', 'OS'], ['NN'], ['SN', 'NS'], ['CCl', 'ClC'], ['CF', 'FC'],
                            ['CBr', 'BrC'], ['others']]
        )
        index = index + _last_hot_index(atom_pair_flags, 'atom pair') * 64
    return index
以一个十进制整数编码化学键特征:$index = a\times 2^0 + b\times 2^2 + c\times 2^3 + d\times 2^4 + e\times 2^6 = a + 4b + 8c + 16d + 64e$。其中 a ∈ {0,1,2,3} 为键类型(单键、双键、三键、芳香键),占二进制最低两位(第 0–1 位);b ∈ {0,1} 表示是否共轭(第 2 位);c ∈ {0,1} 表示是否在环中(第 3 位);d ∈ {0,1,2,3} 为立体构型类别(STEREONONE/STEREOANY/STEREOZ/STEREOE,第 4–5 位),仅在 use_chirality=True 时加入;e 为原子对类别编号(共 12 类),从第 6 位开始,仅在 atompair=True 时加入。
2.3.build_mask
def build_mask(labels_list, mask_value=100):
    """Return a 0/1 validity mask for a sequence of labels.

    Args:
        labels_list: iterable of label values (e.g. one row of task labels).
        mask_value: sentinel value that marks a missing label.

    Returns:
        List with 0 where the label equals ``mask_value`` or is NaN (a missing
        value that was never filled) and 1 everywhere else.
    """
    def _is_missing(value):
        # NaN is the only value not equal to itself, so this also catches
        # float('nan') and NumPy NaNs without importing math/numpy.
        return value == mask_value or value != value

    return [0 if _is_missing(value) else 1 for value in labels_list]
- 如果 label 的值无效,mask 就是 0,有效的话 mask 的值是 1。
- built_data_and_save_for_splited 中先执行 data_origin.fillna(123456),把所有缺失值填为 123456;multi_task_build_dataset 再以 mask_value=123456 调用 build_mask。因此原始数据中缺失(NaN)的标签在这里得到 mask = 0,即被标记为无效。