CS224W作业 class GAT(MessagePassing)

2023年CS22Wassignment中的所有colab答案以及注释已经上传到github:https://github.com/yuyu990116/CS224W-assignment
CS224W课程地址:http://web.stanford.edu/class/cs224w/

class GAT(MessagePassing):

    def __init__(self, in_channels, out_channels, heads = 2,
                 negative_slope = 0.2, dropout = 0., **kwargs):
        super(GAT, self).__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.negative_slope = negative_slope
        self.dropout = dropout
        self.lin_l=Linear(in_channels,heads*out_channels) #W_l
        self.lin_r = self.lin_l  #W_r
        self.att_l = Parameter(torch.Tensor(1, heads, out_channels)) #a_l
        self.att_r = Parameter(torch.Tensor(1, heads, out_channels)) #a_r
        self.reset_parameters()
	def reset_parameters(self):
        nn.init.xavier_uniform_(self.lin_l.weight)
        nn.init.xavier_uniform_(self.lin_r.weight)
        nn.init.xavier_uniform_(self.att_l)
        nn.init.xavier_uniform_(self.att_r)
    def forward(self, x, edge_index, size = None):
        H, C = self.heads, self.out_channels
        x_l=self.lin_l(x)
        x_r=self.lin_r(x)
        x_l=x_l.view(-1,H,C)
        x_r=x_r.view(-1,H,C)
        alpha_l = (x_l * self.att_l).sum(axis=1)  #α_l
        alpha_r = (x_r * self.att_r).sum(axis=1)  #α_r
        out = self.propagate(edge_index, x=(x_l, x_r), alpha=(alpha_l, alpha_r),size=size)
        out = out.view(-1, H * C)
        return out
    def message(self, x_j, alpha_j, alpha_i, index, ptr, size_i):
        alpha = alpha_i + alpha_j 
        alpha = F.leaky_relu(alpha,self.negative_slope)
        #leakeyrelu的负斜率
        alpha = softmax(alpha, index, ptr, size_i)
        alpha = F.dropout(alpha, p=self.dropout, training=self.training).unsqueeze(1)  #[N,1,C]
        out = x_j * alpha 
        return out
  def aggregate(self, inputs, index, dim_size = None):
          out = torch_scatter.scatter(inputs, index, dim=self.node_dim, dim_size=dim_size, reduce='sum')
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值