节点表征学习[GCN、GAT]

节点表征学习[GCN、GAT]

引言

在这里插入图片描述
图神经网络机器学习、深度学习的根本目的是学习图数据,将图结构embedding为计算机可处理的向量矩阵。
在这里插入图片描述
以空域的卷积为例,每一次卷积之后得到的的特征变换实际上是改变了节点的隐状态。得到每个节点聚合和更新之后的隐状态后就可以执行下游的任务(节点分类,边分类,图分类)。本次将比较只是用NN和用图神经网络算法(GCN、GAT)的差异。

节点表征学习

在这里插入图片描述

GCN的公式,在pyG中很好实现,下面看代码

from torch_geometric.nn import GCNConv
import torch

class GCN(torch.nn.Module):
  def __init__(self, hidden_channels):
      super(GCN, self).__init__()
      torch.manual_seed(12345)
      self.conv1 = GCNConv(dataset.num_features,hidden_channels)
      self.conv2 = GCNConv(hidden_channels,dataset.num_classes)

  def forward(self, x, edge_index):
      x = self.conv1(x, edge_index)
      x = x.relu()
      x = F.dropout(x, p=0.5, training=self.training)
      x = self.conv2(x, edge_index)
      return x

我们只需要继承nn.Module这个库,调用pyG的GCNConv当做线性层使用就可以了。
让我们来看看GCNConv的源码,只看Class部分

class GCNConv(MessagePassing):

  def __init__(self, in_channels: int, out_channels: int,
               improved: bool = False, cached: bool = False,
               add_self_loops: bool = True, normalize: bool = True,
               bias: bool = True, **kwargs):

      kwargs.setdefault('aggr', 'add')
      super(GCNConv, self).__init__(**kwargs)

      self.in_channels = in_channels
      self.out_channels = out_channels
      self.improved = improved
      self.cached = cached
      self.add_self_loops = add_self_loops
      self.normalize = normalize

      self._cached_edge_index = None
      self._cached_adj_t = None

      self.weight = Parameter(torch.Tensor(in_channels, out_channels))

      if bias:
          self.bias = Parameter(torch.Tensor(out_channels))
      else:
          self.register_parameter('bias', None)

      self.reset_parameters()

  def reset_parameters(self):
      glorot(self.weight)
      zeros(self.bias)
      self._cached_edge_index = None
      self._cached_adj_t = None

  def forward(self, x: Tensor, edge_index: Adj,
              edge_weight: OptTensor = None) -> Tensor:
      """"""

      if self.normalize:
          if isinstance(edge_index, Tensor):
              cache = self._cached_edge_index
              if cache is None:
                  edge_index, edge_weight = gcn_norm(  # yapf: disable
                      edge_index, edge_weight, x.size(self.node_dim),
                      self.improved, self.add_self_loops)
                  if self.cached:
                      self._cached_edge_index = (edge_index, edge_weight)
              else:
                  edge_index, edge_weight = cache[0], cache[1]

          elif isinstance(edge_index, SparseTensor):
              cache = self._cached_adj_t
              if cache is None:
                  edge_index = gcn_norm(  # yapf: disable
                      edge_index, edge_weight, x.size(self.node_dim),
                      self.improved, self.add_self_loops)
                  if self.cached:
                      self._cached_adj_t = edge_index
              else:
                  edge_index = cache

      x = x @ self.weight

      # propagate_type: (x: Tensor, edge_weight: OptTensor)
      out = self.propagate(edge_index, x=x, edge_weight=edge_weight,
                           size=None)

      if self.bias is not None:
          out += self.bias

      return out

  def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:
      return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j

  def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
      return matmul(adj_t, x, reduce=self.aggr)

  def __repr__(self):
      return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                 self.out_channels)

上节就说到在pyG中要实现消息传递需要继承MessagePassing 方法,同样的这里GCNConv类同样继承了。整个forward过程就是实现了GCN的公式

GAT公式·在这里插入图片描述

import torch
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GATConv
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures

dataset = Planetoid(root='data/Planetoid', name='Cora',
transform=NormalizeFeatures())

class GAT(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(GAT, self).__init__()
        torch.manual_seed(12345)
        self.conv1 = GATConv(dataset.num_features, hidden_channels)
        self.conv2 = GATConv(hidden_channels, dataset.num_classes)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return x
  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值