使用 CSV 文件建立 PyG 数据集，进行分子性质预测
使用自己的分子数据集（本文示例为 CSV 格式；SDF、MOL2 等格式可以先用 RDKit 读入为分子对象）来创建适用于 PyG 的数据集。
需要的第三方库：
RDKit、PyTorch、PyG（torch_geometric），以及 pandas、NumPy、SciPy
前言
在用图神经网络完成任务时，我们往往不满足于直接运行现成的示例（“轮子”），而是要用自己的数据集来预测某些指标。这就需要把自己的数据集构建成适用于 PyG 的 Dataset。
一.步骤
PyG 有两种建立数据集的方式：一种是 InMemoryDataset，会把全部数据读入内存，受内存大小限制；另一种是 Dataset，适合较大的数据集（本例中每个样本单独保存为一个文件，按需读取）。这次我主要分享更通用的 Dataset。
代码如下(示例):
import pandas as pd
import torch
from torch_geometric.data import Dataset, Data
import numpy as np
from scipy.sparse import coo_matrix
import os
from rdkit import Chem
from rdkit.Chem.rdmolops import GetAdjacencyMatrix
class mydataset(Dataset):
    """PyG ``Dataset`` that turns a CSV of SMILES strings into molecular graphs.

    The raw file ``fda1.csv`` (resolved through ``self.raw_paths``) must have a
    ``smiles`` column and a ``values`` column. ``values`` is binarised by
    ``get_label``; the accompanying write-up describes it as IC50 in nM
    (assumption, not checked here). Each row becomes one
    ``torch_geometric.data.Data`` object saved as ``data{i}.pt`` in
    ``self.processed_dir`` and loaded lazily by ``get``.
    """

    def __init__(self, root, transform=None, pre_transform=None):
        super(mydataset, self).__init__(root, transform, pre_transform)

    @property
    def raw_file_names(self):
        # Single raw CSV; PyG exposes its full path as self.raw_paths[0].
        return 'fda1.csv'

    @property
    def processed_file_names(self):
        # PyG checks these names to decide whether process() must run.
        # Reading the CSV here also stores it on self.data, which len() uses.
        self.data = pd.read_csv(self.raw_paths[0]).reset_index()
        # These names must match what process() writes and get() reads;
        # otherwise PyG never finds the cache and re-processes every time.
        return [f'data{i}.pt' for i in range(len(self.data))]

    def download(self):
        # Nothing to download: the raw CSV has to be placed by hand.
        pass

    def process(self):
        """Convert every CSV row into a ``Data`` graph and save it to disk.

        Raises:
            ValueError: if RDKit cannot parse a SMILES string.
        """
        frame = pd.read_csv(self.raw_paths[0])
        smiles = frame['smiles'].values.tolist()
        labels = self.get_label(frame['values']).tolist()[0]

        # Feature widths depend only on the encoders, so measure them once on
        # a tiny reference molecule. This also gives molecules without bonds
        # a correctly shaped (0, n_edge_features) edge_attr.
        reference = Chem.MolFromSmiles('O=O')
        n_node_features = len(self.get_atom_features(reference.GetAtomWithIdx(0)))
        n_edge_features = len(self.get_bond_feature(reference.GetBondBetweenAtoms(0, 1)))

        for index, (smi, label) in enumerate(zip(smiles, labels)):
            mol = Chem.MolFromSmiles(smi)
            if mol is None:
                raise ValueError(f'Row {index}: RDKit could not parse SMILES {smi!r}')
            graph = self._mol_to_graph(mol, label, n_node_features, n_edge_features)
            # Honour the pre_transform accepted by the constructor.
            if self.pre_transform is not None:
                graph = self.pre_transform(graph)
            torch.save(graph, os.path.join(self.processed_dir, f'data{index}.pt'))

    def _mol_to_graph(self, mol, label, n_node_features, n_edge_features):
        """Build a ``Data`` object (x, edge_index, edge_attr, y) for one RDKit molecule."""
        x = np.zeros((mol.GetNumAtoms(), n_node_features))
        for atom in mol.GetAtoms():
            x[atom.GetIdx(), :] = self.get_atom_features(atom)

        # The adjacency matrix is symmetric, so every bond appears once per
        # direction. Deriving edge_index and edge_attr from the same
        # (rows, cols) keeps them aligned row-for-row.
        rows, cols = np.nonzero(GetAdjacencyMatrix(mol))
        edge_index = torch.tensor(np.array([rows, cols]), dtype=torch.long)
        edge_attr = np.zeros((len(rows), n_edge_features))
        for k, (i, j) in enumerate(zip(rows, cols)):
            edge_attr[k] = self.get_bond_feature(mol.GetBondBetweenAtoms(int(i), int(j)))

        return Data(
            x=torch.tensor(x, dtype=torch.float),
            edge_attr=torch.tensor(edge_attr, dtype=torch.float),
            edge_index=edge_index,
            y=torch.tensor(np.array([label]), dtype=torch.long),
        )

    def get_label(self, label, threshold=4):
        """Binarise raw potency values into class labels.

        Each value ``v`` is mapped to ``-log10(v * 1e-9)`` (a pIC50-style value
        if ``v`` is a concentration in nM). The result is 1 where that exceeds
        ``threshold`` and 0 otherwise. Non-positive inputs give inf/nan, with a
        NumPy warning.

        Returns:
            An int array of shape ``(1, len(label))``.
        """
        label = np.asarray([label])
        label = -np.log10(label * 1e-9)
        return np.where(label > threshold, 1, 0)

    def one_hot_k_encode(self, x, permitted_list):
        """One-hot encode ``x`` over ``permitted_list``.

        Values not in the list are mapped to its last element, which therefore
        acts as the "unknown/other" bucket.
        """
        if x not in permitted_list:
            x = permitted_list[-1]
        return [int(x == s) for s in permitted_list]

    def get_atom_features(self, atom, use_chirality=True, hydrogens_implicit=True):
        """Return the node feature vector of an RDKit atom as a 1-D numpy array.

        The vector concatenates one-hot encodings of element, heavy-atom degree,
        formal charge and hybridisation; ring/aromatic flags; and scaled mass,
        van der Waals radius and covalent radius. If ``use_chirality`` is set,
        the chiral tag is appended. If ``hydrogens_implicit`` is set, the total
        hydrogen count is appended; otherwise 'H' is added as an element class.
        """
        permitted_list_of_atoms = ['C', 'N', 'O', 'S', 'Cl', 'I', 'Br', 'F', 'P', 'UNW']
        if not hydrogens_implicit:
            permitted_list_of_atoms = ['H'] + permitted_list_of_atoms
        periodic_table = Chem.GetPeriodicTable()
        atom_type_encode = self.one_hot_k_encode(str(atom.GetSymbol()), permitted_list_of_atoms)
        n_heavy_neighbors_enc = self.one_hot_k_encode(int(atom.GetDegree()), [0, 1, 2, 3, 4, 5, 6])
        formal_charge_enc = self.one_hot_k_encode(int(atom.GetFormalCharge()), [-1, -2, -3, 0, 1, 2, 3, 4, 5, 6, 'Extreme'])
        hybridisation_type_enc = self.one_hot_k_encode(str(atom.GetHybridization()), ['S', 'SP', 'SP2', 'SP3', 'SP3D', 'SP3D2', 'Other'])
        is_in_ring_enc = [int(atom.IsInRing())]
        is_aromatic_enc = [int(atom.GetIsAromatic())]
        # Fixed offset/scale normalisations; the mass constants appear to be
        # the atomic masses of C and I.
        atomic_mass_scaled = [float((atom.GetMass() - 12.011) / 126.904)]
        vdw_radius_enc = [float((periodic_table.GetRvdw(atom.GetAtomicNum()) - 1.5) / 0.6)]
        covalent_radius_enc = [float((periodic_table.GetRcovalent(atom.GetAtomicNum()) - 0.64) / 0.76)]
        atom_feature_vector = (atom_type_encode + n_heavy_neighbors_enc + formal_charge_enc
                               + hybridisation_type_enc + is_in_ring_enc + is_aromatic_enc
                               + atomic_mass_scaled + vdw_radius_enc + covalent_radius_enc)
        if use_chirality:
            chirality_type_enc = self.one_hot_k_encode(str(atom.GetChiralTag()), ["CHI_UNSPECIFIED", "CHI_TETRAHEDRAL_CW", "CHI_TETRAHEDRAL_CCW", "CHI_OTHER"])
            atom_feature_vector += chirality_type_enc
        if hydrogens_implicit:
            n_hydrogens_enc = self.one_hot_k_encode(int(atom.GetTotalNumHs()), [0, 1, 2, 3, 4, 5, 6, 'MoreThanSix'])
            atom_feature_vector += n_hydrogens_enc
        return np.array(atom_feature_vector)

    def get_bond_feature(self, bond, use_stereochemistry=True):
        """Return the edge feature vector of an RDKit bond as a 1-D numpy array.

        Layout: bond type one-hot (single/double/triple/aromatic; others fall
        into the last slot), in-ring flag, conjugation flag and, if
        ``use_stereochemistry`` is set, a stereo one-hot.
        """
        permitted_list_of_bond_types = [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE, Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC]
        bond_type_enc = self.one_hot_k_encode(bond.GetBondType(), permitted_list_of_bond_types)
        bond_is_conj_enc = [int(bond.GetIsConjugated())]
        bond_is_ring_enc = [int(bond.IsInRing())]
        bond_feature_vector = bond_type_enc + bond_is_ring_enc + bond_is_conj_enc
        if use_stereochemistry:
            stereo_type_enc = self.one_hot_k_encode(str(bond.GetStereo()), ["STEREOZ", "STEREOE", "STEREOANY", "STEREONONE"])
            bond_feature_vector += stereo_type_enc
        return np.array(bond_feature_vector)

    def len(self):
        # self.data is populated by processed_file_names, which PyG evaluates
        # while checking for processed files during __init__.
        return self.data.shape[0]

    def get(self, idx):
        """Load the ``Data`` graph that ``process()`` saved for row ``idx``."""
        path = os.path.join(self.processed_dir, f'data{idx}.pt')
        # Newer torch defaults torch.load to weights_only=True, which rejects
        # pickled Data objects. These files are produced locally by process().
        try:
            return torch.load(path, weights_only=False)
        except TypeError:  # old torch without the weights_only argument
            return torch.load(path)
if __name__ == '__main__':
    # Build (or load the cached) dataset rooted at data/ and inspect one graph:
    # node features, edge_index shape, edge_attr shape and label.
    datasets = mydataset('data/')
    sample = datasets[1]
    for value in (sample.x, sample.edge_index.size(), sample.edge_attr.size(), sample.y):
        print(value)
二、输出
运行结果如下（示例）：
Processing...
Done!
tensor([[0., 0., 0., ..., 0., 0., 0.],
[1., 0., 0., ..., 0., 0., 0.],
[1., 0., 0., ..., 0., 0., 0.],
...,
[1., 0., 0., ..., 0., 0., 0.],
[0., 1., 0., ..., 0., 0., 0.],
[1., 0., 0., ..., 0., 0., 0.]])
torch.Size([2, 56])
torch.Size([56, 10])
tensor([0])
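以上示例使用的 CSV 文件包含两列：smiles（分子的 SMILES 字符串）和 values（IC50 值，单位 nM）。
标签按 pIC50 = -log10(IC50 × 10⁻⁹) 计算：pIC50 > 4 记为 1，否则记为 0。
整体流程是：先将 SMILES 转化为 RDKit 的 Mol 对象，再提取原子和化学键的特征，分别作为节点特征（x）和边特征（edge_attr），最后逐个保存为 PyG Dataset 可以读取的 .pt 文件。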
有什么不对的地方，或者有不懂的地方，欢迎一起探讨。