2021秋季课程cs224w(图机器学习) colab4 GAT

官网链接:cs244

colab4里需要自己实现的部分是单层GAT,下面就是GAT单层实现部分:
参考pyG的具体实现即可:https://pytorchgeometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/conv/gat_conv.html#GATConv

MessagePassing类的各种函数定义就不放了,colab里有。下面是相关公式和代码

公式:
GAT可以带权重的聚合节点的邻居信息,总公式为公式(1):在这里插入图片描述

αij为注意力权重:公式(2):
在这里插入图片描述
上面的两个公式就是下面要实现的过程(采用Transtive方式进行训练)
(下面是按照无向图实现的)

实现逻辑:
1、对每个节点应用线性转换(计算Wh)
在这里插入图片描述
将维度为[N,F]的特征矩阵,转换为[N,F’]的特征矩阵

forward函数中:
w_l = self.lin_l(x).view(-1, H, C)#[N,H,C]
w_r = self.lin_r(x).view(-1, H, C)#[N,H,C]

2、计算注意力(在message函数中计算):按照公式(2)逐步计算
对于一组点(i,j),首先与要计算alpha_l和alpha_r,即:


alpha_l = (w_l * self.att_l).sum(-1) #alpha_l的维度:[N,H]
alpha_r = (w_r * self.att_r).sum(-1)#[N,H]

之后只用LeakRelu:

message函数:
alpha = F.leaky_relu(alpha, self.negative_slope)
alpha = softmax(alpha, index, ptr, size_i)
alpha = F.dropout(alpha, p=self.dropout, training=self.training)
#此时alpha维度为[E,H]

计算完成后的alpha就是每条边(i,j)的注意力

3、在message过程中乘上加权。
这里的E是输入图的边数,因为是Transtive的训练方式,因此E是整个图的边数。

message函数:
alpha = alpha.unsqueeze(-1) #alpha = [E,H,1](扩展最后一维)
out = x_j * alpha #[N,H,C] * [E,H,C](逐元素相乘,乘的时候,内部的[E.H,1]会被广播为[E,H,C],有疑问可查‘’pytorch中的广播机制’)
#out=[E,H,C]

4、聚合:对邻接节点的嵌入表示进行求和

单层GAT代码:

class GAT(MessagePassing):

    def __init__(self, in_channels, out_channels, heads = 2,
                 negative_slope = 0.2, dropout = 0., **kwargs):
        super(GAT, self).__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.negative_slope = negative_slope
        self.dropout = dropout

        self.lin_l = None
        self.lin_r = None
        self.att_l = None
        self.att_r = None

        ##################定义线性层,即两个线性层和两个权重向量#####
        # TODO: Your code here! 
        self.lin_l = nn.Linear(self.in_channels, self.out_channels * self.heads)
        self.lin_r = self.lin_l
        self.att_r = Parameter(torch.Tensor(1, self.heads, self.out_channels))
        self.att_l = Parameter(torch.Tensor(1, self.heads, self.out_channels))

        ############################################################################


        ############################################################################

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.lin_l.weight)
        nn.init.xavier_uniform_(self.lin_r.weight)
        nn.init.xavier_uniform_(self.att_l)
        nn.init.xavier_uniform_(self.att_r)

    def forward(self, x, edge_index, size = None):
        
        H, C = self.heads, self.out_channels

        ############################################################################
       # TODO: Your code here! 
        w_l = self.lin_l(x).view(-1, H, C)
        w_r = self.lin_r(x).view(-1, H, C)
        # print(w_r.shape)
        alpha_l = (w_l * self.att_l).sum(-1)
        alpha_r = (w_r * self.att_r).sum(-1)
        
        out = self.propagate(edge_index,x = (w_l, w_r), alpha = (alpha_l, alpha_r), dim_size=size).view(-1, H * C)
        ############################################################################

        return out


    def message(self, x_j, alpha_j, alpha_i, index, ptr, size_i):

        ############################################################################
        # TODO: Your code here! 
        # alpha = [E,H]
        # x = [N,H,C]
        H, C = self.heads, self.out_channels

        alpha = alpha_i + alpha_j

        alpha = F.leaky_relu(alpha, self.negative_slope)
        alpha = softmax(alpha, index, ptr, size_i)
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)

        alpha = alpha.unsqueeze(-1) #alpha = [E,H,1]
        #alpha [E, H, C]
        out = x_j * alpha #[N,H,C] * [E,H,C](alpha进行广播)
        
        ############################################################################

        return out


    def aggregate(self, inputs, index, dim_size = None):

        ############################################################################
        # TODO: Your code here! 
        
        out = torch_scatter.scatter(inputs, index, reduce='sum', dim=0)
        ############################################################################
    
        return out
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值