Spatial Graph Convolutional Networks(SGCN)

利用空间信息的图卷积网络,代码:geo-gcn,原文:Spatial Graph Convolutional Networks

#在colab里配置运行环境

1.load_dataset

# Usage sketch: build train/val/test DataLoaders for one of the
# geo-gcn molecule datasets.
from torch_geometric.data import DataLoader
from chem import load_dataset

batch_size = 64
dataset_name = ...  # 'freesolv' / 'esol' / 'bbbp'

# Each fold is loaded from its own CSV file via load_dataset.
train_dataset = load_dataset(dataset_name, 'train')
val_dataset = load_dataset(dataset_name, 'val')
test_dataset = load_dataset(dataset_name, 'test')

# Only the training loader shuffles; evaluation order stays deterministic.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

# training loop
... 
def load_dataset(dataset_name, fold_name, path='../data/molecules'):
    """Load one fold ('train'/'val'/'test') of a molecule dataset.

    Reads `<dataset>_<fold>.csv` under *path*, featurizes the molecules
    and converts every (features, label) pair into a PyG Data object.
    """
    csv_name = '{}_{}.csv'.format(dataset_name.lower(), fold_name)
    x, y = load_data_from_df(os.path.join(path, csv_name))
    # Each sample becomes [afm, adj, positions, label] for the transform step.
    samples = [[*features, target] for features, target in zip(x, y)]
    return transform_dataset_pg(samples)

1.1.load_data_from_df

def load_data_from_df(dataset_path):
    """Read a CSV (SMILES in column 0, target in column 1) and featurize it.

    Returns the (x_all, y_all) pair produced by load_data_from_smiles.
    """
    frame = pd.read_csv(dataset_path)

    smiles_column = frame.iloc[:, 0].values
    target_column = frame.iloc[:, 1].values

    # float64 targets indicate a regression task; downcast to float32
    # for training. Integer targets (classification) pass through as-is.
    if target_column.dtype == np.float64:
        target_column = target_column.astype(np.float32)

    return load_data_from_smiles(smiles_column, target_column)
  • 载入SMILES和预测目标,如果预测目标是np.float64类型,说明是回归模型,否则是分类模型
def load_data_from_smiles(x_smiles, labels, normalize_features=False):
    """Convert SMILES strings into graph descriptors.

    For every molecule: parse the SMILES, add hydrogens, embed a 3-D
    conformer, relax it with UFF, strip the hydrogens again, and
    featurize. Molecules that cannot be converted are skipped with a
    warning.

    Parameters:
        x_smiles -- iterable of SMILES strings
        labels -- iterable of prediction targets, aligned with x_smiles
        normalize_features -- if True, min-max scale the atom feature
            matrices over the whole dataset

    Returns (x_all, y_all) where each x_all entry is
    [atom_features, adjacency, positions] and each y_all entry is [label].
    """
    x_all, y_all = [], []
    for smiles, label in zip(x_smiles, labels):
        try:
            if len(smiles) < 2:
                raise ValueError('SMILES string is too short')

            mol = MolFromSmiles(smiles)
            if mol is None:
                # RDKit returns None for unparsable SMILES; without this
                # check Chem.AddHs(None) raises an error that the
                # except ValueError below would not catch.
                raise ValueError('RDKit could not parse the SMILES')

            mol = Chem.AddHs(mol)
            AllChem.EmbedMolecule(mol)
            AllChem.UFFOptimizeMolecule(mol)
            mol = Chem.RemoveHs(mol)

            afm, adj, mat_positions = featurize_mol(mol)
            x_all.append([afm, adj, mat_positions])
            y_all.append([label])
        except ValueError as e:
            logging.warning('the SMILES ({}) can not be converted to a graph.\nREASON: {}'.format(smiles, e))

    if normalize_features:
        x_all = feature_normalize(x_all)
    return x_all, y_all
  • 加氢之后产生3D构象,然后计算原子在空间中的相对位置,这里 normalize_features 默认是 False,afm是指 atom feature matrix

1.1.1.featurize_mol

def featurize_mol(mol):
    """Build the three per-molecule matrices used by SGCN.

    Returns (node_features, adj_matrix, pos_matrix):
        node_features -- one feature row per atom (see get_atom_features)
        adj_matrix -- symmetric bond adjacency with self-loops (identity base)
        pos_matrix -- (num_atoms, 3) conformer coordinates
    """
    conformer = mol.GetConformer()
    num_atoms = mol.GetNumAtoms()

    node_features = np.array([get_atom_features(a) for a in mol.GetAtoms()])

    # Start from the identity so every atom is connected to itself.
    adj_matrix = np.eye(num_atoms)
    for bond in mol.GetBonds():
        u = bond.GetBeginAtom().GetIdx()
        v = bond.GetEndAtom().GetIdx()
        adj_matrix[u, v] = 1
        adj_matrix[v, u] = 1

    pos_matrix = np.array(
        [[conformer.GetAtomPosition(k).x,
          conformer.GetAtomPosition(k).y,
          conformer.GetAtomPosition(k).z] for k in range(num_atoms)])

    return node_features, adj_matrix, pos_matrix

def get_atom_features(atom):
    """Encode one RDKit atom as a float32 feature vector.

    Layout: one-hot atomic number (11 options, 999 = catch-all)
    + one-hot neighbor count (6) + one-hot attached-H count (5)
    + formal charge + in-ring flag + aromaticity flag.
    """
    features = []

    features.extend(one_hot_vector(
        atom.GetAtomicNum(),
        [5, 6, 7, 8, 9, 15, 16, 17, 35, 53, 999]))

    features.extend(one_hot_vector(
        len(atom.GetNeighbors()),
        [0, 1, 2, 3, 4, 5]))

    features.extend(one_hot_vector(
        atom.GetTotalNumHs(),
        [0, 1, 2, 3, 4]))

    features.append(atom.GetFormalCharge())
    features.append(atom.IsInRing())
    features.append(atom.GetIsAromatic())

    return np.array(features, dtype=np.float32)
    
def one_hot_vector(val, lst):
    """Return a one-hot encoding of *val* over the options in *lst*.

    A value not present in *lst* maps to the last option (the catch-all
    bucket, e.g. atomic number 999). Returns a list of bools rather
    than the one-shot ``map`` object the previous version returned, so
    the result can be iterated more than once and indexed; all existing
    callers only iterate it, so this is backward compatible.
    """
    if val not in lst:
        val = lst[-1]
    return [x == val for x in lst]
  • 构造特征矩阵,邻接矩阵和位置矩阵

1.1.2.feature_normalize

def feature_normalize(x_all):
    """Min-max scale every atom feature matrix in *x_all*.

    x_all is a list of [afm, adj, pos] triples; only the atom feature
    matrix (index 0) is rescaled, using per-column extrema computed
    over the whole dataset. Constant columns are left unchanged by
    forcing their range to 1 before dividing.
    """
    lo = x_all[0][0].min(axis=0)
    hi = x_all[0][0].max(axis=0)
    for sample in x_all[1:]:
        lo = np.minimum(lo, sample[0].min(axis=0))
        hi = np.maximum(hi, sample[0].max(axis=0))

    span = hi - lo
    span[span == 0] = 1.

    for sample in x_all:
        sample[0] = (sample[0] - lo) / span

    return x_all
  • 最大最小归一化

1.2.transform_dataset_pg

def transform_dataset_pg(dataset):
    """Convert every [afm, adj, positions, label] entry into a PyG Data object."""
    return [transform_molecule_pg(mol) for mol in dataset]
  • mol包含了四个数据,下面全部转换到torch.tensor
def transform_molecule_pg(mol):
    """Pack one molecule's arrays into a torch_geometric Data object.

    mol is [atom_features, adjacency, positions, label] as produced by
    load_dataset.
    """
    afm, adj, positions, label = mol

    # Data expects edge_index shaped [2, num_edges]:
    # row 0 = source nodes, row 1 = target nodes.
    edge_index = torch.tensor(get_edge_indices(adj)).t().contiguous()

    return Data(x=torch.tensor(afm),
                y=torch.tensor(label),
                edge_index=edge_index,
                pos=torch.tensor(positions))
    
def get_edge_indices(adj):
    """Collect the (i, j) pairs with i <= j where the adjacency matrix is 1.

    Only the upper triangle (diagonal included) is scanned, so each
    undirected bond yields a single directed edge and self-loops are kept.
    """
    n = adj.shape[0]
    return [(i, j) for i in range(n) for j in range(i, n) if adj[i, j] == 1]

  • 构造pyg需要的边列表,以对应参数转成Data类型,这样接入了pytorch的框架
import numpy as np

# Toy adjacency: node 0 has only a self-loop; nodes 1 and 2 are
# connected to each other and to themselves.
adj=np.array([
    [1,0,0],
    [0,1,1],
    [0,1,1]
])
def get_edge_indices(adj):
    # Scan only the upper triangle (j starts at i), so each undirected
    # edge appears once and self-loops are kept.
    edges_list = []
    for i in range(adj.shape[0]):
        for j in range(i, adj.shape[0]):
            if adj[i, j] == 1:
                edges_list.append((i, j))
    return edges_list
get_edge_indices(adj) # [(0, 0), (1, 1), (1, 2), (2, 2)] -- a plain list of tuples, not a tensor
torch.tensor(get_edge_indices(adj)).t().contiguous() #tensor([[0, 1, 1, 2],[0, 1, 2, 2]])
  • 边列表第一行是源节点,第二行是目标节点。这里 j 从 i 开始遍历,只保留了上三角的边,因此 (2,1) 这条反向边没有加入,后面也没有再补。虽然 adj 里是对称地表达了这条边,但 adj 在生成边列表之后就不再被使用;而 PyG 的 edge_index 是有向的,这意味着消息只会从编号小的节点流向编号大的节点(外加 forward 里补的自环)。严格来说应当同时加入 (j, i),或对 edge_index 调用 to_undirected。

2.SpatialGraphConv

import torch
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops


class SpatialGraphConv(MessagePassing):
    """Spatial graph convolution (SGCN) layer.

    Each neighbour's features are gated element-wise by a learned
    function of the relative position p_j - p_i, then summed over the
    neighbourhood (aggr='add') and projected to out_channels.
    """

    def __init__(self, coors, in_channels, out_channels, hidden_size, dropout=0):
        """
        coors - dimension of positional descriptors (e.g. 2 for 2D images)
        in_channels - number of the input channels (node features)
        out_channels - number of the output channels (node features)
        hidden_size - number of the inner convolutions
        dropout - dropout rate after the layer
        """
        super(SpatialGraphConv, self).__init__(aggr='add')
        self.dropout = dropout
        # Implements ReLU(U^T (p_j - p_i) + b): one gate per
        # (in_channel, hidden) pair, computed from the coors-dim offset.
        self.lin_in = torch.nn.Linear(coors, hidden_size * in_channels)
        self.lin_out = torch.nn.Linear(hidden_size * in_channels, out_channels)
        self.in_channels = in_channels

    def forward(self, x, pos, edge_index):
        """
        x - feature matrix of the whole graph [num_nodes, label_dim]
        pos - node position matrix [num_nodes, coors]
        edge_index - graph connectivity [2, num_edges]
        """
        # Self-loops let every node keep its own features in the sum.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        return self.propagate(edge_index=edge_index, x=x, pos=pos, aggr='add')

    def message(self, pos_i, pos_j, x_j):
        """Per-edge message. The _i/_j suffixes are read by PyG's
        inspector: _i refers to the target node, _j to the source node.

        pos_i, pos_j [num_edges, coors]; x_j [num_edges, label_dim]
        """
        offset = pos_j - pos_i                        # [n_edges, coors]
        gate = F.relu(self.lin_in(offset))            # [n_edges, hidden_size * in_channels]

        n_edges = gate.size(0)
        # Broadcast multiply: [n_edges, in_channels, hidden] * [n_edges, in_channels, 1]
        gated = gate.reshape(n_edges, self.in_channels, -1) * x_j.unsqueeze(-1)
        # Flatten back so the default 'add' aggregation can sum per node.
        return gated.view(n_edges, -1)

    def update(self, aggr_out):
        """Project the aggregated messages to out_channels, then ReLU + dropout."""
        projected = F.relu(self.lin_out(aggr_out))
        return F.dropout(projected, p=self.dropout, training=self.training)
  • 官网自定义MPNN的解释机制实现示例
  • SGCN 关键公式是 $h_i(U,b)=\sum_{j\in N_i}\mathrm{ReLU}\big(U^T(p_j-p_i)+b\big)\odot h_j$
  • 一个SGCN层输入 x,pos,edge_index,输出 aggr_out,shape是 [num_nodes, label_dim, out_features]
  • 参数后缀很重要,i表示与target节点相关的参数,j表示source节点相关的参数
  • 通过Inspector类收集参数,aggr=‘add’ 而没有自定义 aggregate 函数,这个函数对 shape 做了一些变化,result 的shape是(n_edges,hidden_size * in_channels),然后经过 aggregate 输出aggr_out,然后经过nn.Linear(hidden_size * in_channels, out_channels),后面注释的shape就变成了[num_nodes, label_dim, out_features]?中间的shape变化过程没有理解,应该可以参考这里提到的add默认函数
def aggregate(self, x_j, edge_index):
    """Reference implementation of the default aggr='add' step.

    Messages x_j are scatter-summed into their *target* nodes, whose
    indices live in row 1 of edge_index. The original snippet unpacked
    `row, _ = edge_index` but then passed the undefined name `col` to
    scatter -- a NameError; unpack the target index instead.
    """
    _, col = edge_index  # edge_index[1]: target node of each edge
    aggr_out = scatter(x_j, col, dim=-2, reduce='sum')
    return aggr_out

3.normalized_cut_2d

def normalized_cut_2d(edge_index, pos):
    """Weight each edge by the Euclidean distance between its endpoint
    positions and feed the weights into torch_geometric's normalized_cut."""
    src, dst = edge_index
    distances = (pos[src] - pos[dst]).norm(p=2, dim=1)
    return normalized_cut(edge_index, distances, num_nodes=pos.size(0))

4.SGCN

class SGCN(torch.nn.Module):
    """Stack of SpatialGraphConv layers with optional graclus cluster
    pooling, followed by global mean pooling and a linear classifier
    emitting log-probabilities."""

    def __init__(self, dim_coor, out_dim, input_features,
                 layers_num, model_dim, out_channels_1, dropout,
                 use_cluster_pooling):
        super(SGCN, self).__init__()
        self.layers_num = layers_num
        self.use_cluster_pooling = use_cluster_pooling

        convs = []
        for idx in range(layers_num):
            # First layer maps input_features -> model_dim;
            # the remaining layers keep model_dim -> model_dim.
            in_ch = input_features if idx == 0 else model_dim
            convs.append(SpatialGraphConv(coors=dim_coor,
                                          in_channels=in_ch,
                                          out_channels=model_dim,
                                          hidden_size=out_channels_1,
                                          dropout=dropout))
        self.conv_layers = torch.nn.ModuleList(convs)

        self.fc1 = torch.nn.Linear(model_dim, out_dim)

    def forward(self, data):
        for conv in self.conv_layers:
            data.x = conv(data.x, data.pos, data.edge_index)

            if self.use_cluster_pooling:
                # Coarsen the graph after each conv: cluster via graclus
                # on distance-weighted normalized cut, then max-pool the
                # node features within each cluster.
                weight = normalized_cut_2d(data.edge_index, data.pos)
                cluster = graclus(data.edge_index, weight, data.x.size(0))
                data = max_pool(cluster, data, transform=T.Cartesian(cat=False))

        # One vector per graph, then classify.
        data.x = global_mean_pool(data.x, data.batch)
        return F.log_softmax(self.fc1(data.x), dim=1)
  • 1
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 5
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

_森罗万象

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值