使用CSV文件建立PYG数据集,进行分子预测

使用CSV文件建立PYG数据集,进行分子预测

使用自己的分子数据集(csv,sdf,mol2)来创建适用于PyG的任务。
需要的第三方库
rdkit, pytorch, pyg,



前言

在使用图神经网络进行任务是,我们有些仅仅使用轮子,所以要使用自己的数据集来预测某些指标。所以我们就需要用自己的数据集建立适合PyG的Dataset。


一.步骤

PyG有两种方法建立数据集,一种是直接读到内存中去的InMemoryDataset,限制是你的内存大小,还有一种是建立比较大的数据集的Dataset。这次我主要向分享比较全能使用的Dataset。

代码如下(示例):

import pandas as pd
import torch
from torch_geometric.data import Dataset, Data
import numpy as np
from scipy.sparse import coo_matrix
import os
from rdkit import Chem
from rdkit.Chem.rdmolops import GetAdjacencyMatrix


class mydataset(Dataset):
    def __init__(self, root, transform=None, pre_transform=None):
        super(mydataset, self).__init__(root, transform, pre_transform)

    @property
    def raw_file_names(self):
        return 'fda1.csv'

    @property
    def processed_file_names(self):
        self.data = pd.read_csv(self.raw_paths[0]).reset_index()
        return [f'data_{i}.pt' for i in list(self.data.index)]

    def download(self):
        pass

    def process(self):
        data = pd.read_csv(self.raw_paths[0])
        smiles = data['smiles'].values.tolist()
        label = self.get_label(data['values']).tolist()[0]
        for index, (smi, y) in enumerate(zip(smiles, label)):
            mol = Chem.MolFromSmiles(smi)
            n_nodes = mol.GetNumAtoms()
            n_edges = 2 * mol.GetNumBonds()
            unrelated_smiles = 'O=O'
            unrelated_mol = Chem.MolFromSmiles(unrelated_smiles)
            n_node_features = len(self.get_atom_features(unrelated_mol.GetAtomWithIdx(0)))
            n_edge_features = len(self.get_bond_feature(unrelated_mol.GetBondBetweenAtoms(0, 1)))
            x = np.zeros((n_nodes, n_node_features))
            y = np.array(y)
            y = torch.tensor(np.array([y]), dtype=torch.long)
            for atom in mol.GetAtoms():
                x[atom.GetIdx(), :] = self.get_atom_features(atom)
            x = torch.tensor(x, dtype=torch.float)

            A = GetAdjacencyMatrix(mol)
            coo_A = coo_matrix(A)
            edge_index = np.array([coo_A.row.tolist(), coo_A.col.tolist()])
            edge_index = torch.tensor(edge_index, dtype=torch.long)
            rows, cols = np.nonzero(A)
            EF = np.zeros((n_edges, n_edge_features))
            for k, (i, j) in enumerate(zip(rows, cols)):
                EF[k] = self.get_bond_feature(mol.GetBondBetweenAtoms(int(i), int(j)))
            edge_weight = torch.tensor(EF, dtype=torch.float)
            data = Data(x=x, edge_attr=edge_weight, edge_index=edge_index, y=y)
            torch.save(data, os.path.join(self.processed_dir, f'data{index}.pt'))

    def get_label(self, label):
        label = np.asarray([label])
        label = -np.log10(label * 1e-9)
        y = np.where(label > 4, 1, 0)
        return y

    def one_hot_k_encode(self, x, permitted_list):

        if x not in permitted_list:
            x = permitted_list[-1]
        binary_encoding = [int(boolean_value) for boolean_value in list(map(lambda s: x == s, permitted_list))]
        return binary_encoding

    def get_atom_features(self, atom, use_chirality=True, hydrogens_implicit=True):

        permitted_list_of_atoms = ['C', 'N', 'O', 'S', 'Cl', 'I', 'Br', 'F', 'P', 'UNW']

        if not hydrogens_implicit:
            permitted_list_of_atoms = ['H'] + permitted_list_of_atoms
        atom_type_encode = self.one_hot_k_encode(str(atom.GetSymbol()), permitted_list_of_atoms)
        n_heavy_neighbors_enc = self.one_hot_k_encode(int(atom.GetDegree()), [0, 1, 2, 3, 4, 5, 6])
        formal_charge_enc = self.one_hot_k_encode(int(atom.GetFormalCharge()), [-1, -2, -3, 0, 1, 2, 3, 4, 5, 6, 'Extreme'])
        hybridisation_type_enc = self.one_hot_k_encode(str(atom.GetHybridization()), ['S', 'SP', 'SP2', 'SP3', 'SP3D', 'SP3D2', 'Other'])
        is_in_ring_enc = [int(atom.IsInRing())]
        is_aromatic_enc = [int(atom.GetIsAromatic())]
        atomic_mass_scaled = [float((atom.GetMass() - 12.011) / 126.904)]
        vdw_radius_enc = [float((Chem.GetPeriodicTable().GetRvdw(atom.GetAtomicNum()) - 1.5) / 0.6)]
        covalent_radius_enc = [float((Chem.GetPeriodicTable().GetRcovalent(atom.GetAtomicNum()) - 0.64) / 0.76)]
        atom_feature_vector = atom_type_encode + n_heavy_neighbors_enc + formal_charge_enc + hybridisation_type_enc + is_in_ring_enc + is_aromatic_enc + atomic_mass_scaled + vdw_radius_enc + covalent_radius_enc

        if use_chirality:
            chirality_type_enc = self.one_hot_k_encode(str(atom.GetChiralTag()), ["CHI_UNSPECIFIED", "CHI_TETRAHEDRAL_CW", "CHI_TETRAHEDRAL_CCW", "CHI_OTHER"])
            atom_feature_vector += chirality_type_enc

        if hydrogens_implicit:
            n_hydrogens_enc = self.one_hot_k_encode(int(atom.GetTotalNumHs()), [0, 1, 2, 3, 4, 5, 6, 'MoreThanSix'])
            atom_feature_vector += n_hydrogens_enc

        return np.array(atom_feature_vector)

    def get_bond_feature(self, bond, use_stereochemistry=True):

        permitted_list_of_bond_types = [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE, Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC]
        bond_type_enc = self.one_hot_k_encode(bond.GetBondType(), permitted_list_of_bond_types)
        bond_is_conj_enc = [int(bond.GetIsConjugated())]
        bond_is_ring_enc = [int(bond.IsInRing())]
        bond_feature_vector = bond_type_enc + bond_is_ring_enc + bond_is_conj_enc

        if use_stereochemistry:
            stereo_type_enc = self.one_hot_k_encode(str(bond.GetStereo()), ["STEREOZ", "STEREOE", "STEREOANY", "STEREONONE"])
            bond_feature_vector += stereo_type_enc

        return np.array(bond_feature_vector)

    def len(self):
        return self.data.shape[0]

    def get(self, idx):
        data = torch.load(os.path.join(self.processed_dir, f'data{idx}.pt'))
        return data

if __name__ == '__main__':
    datasets = mydataset('data/')
    print(datasets[1].x)
    print(datasets[1].edge_index.size())
    print(datasets[1].edge_attr.size())
    print(datasets[1].y)

二.输出

代码如下(示例):
在这里插入图片描述
在这里插入图片描述

Processing...
Done!
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [1., 0., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.]])
torch.Size([2, 56])
torch.Size([56, 10])
tensor([0])

以上主要是,csv格式
在这里插入图片描述
smiles,和IC50值(nM),
以上主要是将smiles格式转化为mol格式,然后提取原子和键的特征。作为节点特征和边的权重。然后储存为Dataset格式。


有什么不对的,或者有什么不懂得欢迎一起探讨。

  • 5
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

battle不停息

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值