Implementing a Batched GAT in PyTorch




Preface

I implemented this batched version of GAT last year and have been meaning to write it up as a blog post ever since, but I was too busy and kept putting it off (the truth is I was too lazy to write it). While reproducing a paper I needed GAT, and I found that essentially none of the available implementations supported batching, so I bit the bullet and adapted one from earlier work; the model I trained with it worked quite well. Speaking of which, I am also curious why there are no batched graph neural network implementations online; most of them fake batching by concatenating graphs into one large graph. Is there some advantage to not using true batching that I am unaware of? That is a question worth thinking about. I have forgotten many of the implementation details by now, so this is a code-only post. The lesson here: write your blog posts early!


Code

import torch
import torch.nn as nn
import torch.nn.functional as F

class GraphAttentionLayer(nn.Module):
    def __init__(self,in_feature,out_feature,dropout,alpha,concat=True):
        super(GraphAttentionLayer,self).__init__()
        self.in_feature=in_feature
        self.out_feature=out_feature
        self.dropout=dropout
        self.alpha=alpha
        self.concat=concat

        self.Wlinear=nn.Linear(in_feature,out_feature)
        # self.W=nn.Parameter(torch.empty(size=(batch_size,in_feature,out_feature)))
        nn.init.xavier_uniform_(self.Wlinear.weight,gain=1.414)

        self.aiLinear=nn.Linear(out_feature,1)
        self.ajLinear=nn.Linear(out_feature,1)
        # self.a=nn.Parameter(torch.empty(size=(batch_size,2*out_feature,1)))
        nn.init.xavier_uniform_(self.aiLinear.weight,gain=1.414)
        nn.init.xavier_uniform_(self.ajLinear.weight,gain=1.414)

        self.leakyRelu=nn.LeakyReLU(self.alpha)


    def getAttentionE(self,Wh):
        # This is the function that was mainly changed to support batching
        Wh1=self.aiLinear(Wh)   # Wh:(batch,node,out_feature) => Wh1:(batch,node,1)
        Wh2=self.ajLinear(Wh)   # Wh:(batch,node,out_feature) => Wh2:(batch,node,1)
        Wh2=Wh2.view(Wh2.shape[0],Wh2.shape[2],Wh2.shape[1])   # (batch,node,1) -> (batch,1,node), equivalent to transpose(1,2) here
        # Wh1=torch.bmm(Wh,self.a[:,:self.out_feature,:])    # Wh:(batch,node,out_feature), a[:,:out_feature,:]:(batch,out_feature,1) => Wh1:(batch,node,1)
        # Wh2=torch.bmm(Wh,self.a[:,self.out_feature:,:])    # Wh:(batch,node,out_feature), a[:,out_feature:,:]:(batch,out_feature,1) => Wh2:(batch,node,1)

        e=Wh1+Wh2   # broadcast add => e:(batch,node,node)
        return self.leakyRelu(e)
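    # Shape note for the broadcast trick above (added for clarity, sizes are illustrative):
    #   Wh1 has shape (batch, node, 1) and plays the role of the "a_i^T Wh_i" term,
    #   Wh2 reshaped to (batch, 1, node) plays the role of the "a_j^T Wh_j" term,
    #   so Wh1 + Wh2 broadcasts to (batch, node, node) with e[b,i,j] = a_i^T Wh_i + a_j^T Wh_j,
    #   matching the GAT paper's e_ij = LeakyReLU(a^T [Wh_i || Wh_j]) with a split into [a_i ; a_j].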

    def forward(self,h,adj):
        # h:(batch,node,in_feature), adj:(batch,node,node)
        Wh=self.Wlinear(h)
        # Wh=torch.bmm(h,self.W)   # h:(batch,node,in_feature), W:(batch,in_feature,out_feature) => Wh:(batch,node,out_feature)
        e=self.getAttentionE(Wh)

        zero_vec=-1e9*torch.ones_like(e)
        attention=torch.where(adj>0,e,zero_vec)       # mask out attention scores on non-edges
        attention=F.softmax(attention,dim=2)          # normalize over each node's neighbors
        attention=F.dropout(attention,self.dropout,training=self.training)
        h_hat=torch.bmm(attention,Wh)  # attention:(batch,node,node), Wh:(batch,node,out_feature) => h_hat:(batch,node,out_feature)

        if self.concat:
            return F.elu(h_hat)
        else:
            return h_hat

    def __repr__(self):
        return self.__class__.__name__+' ('+str(self.in_feature)+'->'+str(self.out_feature)+')'


class GAT(nn.Module):
    def __init__(self,in_feature,hidden_feature,out_feature,attention_layers,dropout,alpha):
        super(GAT,self).__init__()
        self.in_feature=in_feature
        self.out_feature=out_feature
        self.hidden_feature=hidden_feature
        self.dropout=dropout
        self.alpha=alpha
        self.attention_layers=attention_layers

        self.attentions=[GraphAttentionLayer(in_feature,hidden_feature,dropout,alpha,True) for i in range(attention_layers)]

        for i,attention in enumerate(self.attentions):
            self.add_module('attention_{}'.format(i),attention)

        self.out_attention=GraphAttentionLayer(attention_layers*hidden_feature,out_feature,dropout,alpha,False)


    def forward(self,h,adj):
        # h:(batch,node,in_feature), adj:(batch,node,node)
        h=F.dropout(h,self.dropout,training=self.training)

        # multi-head attention: concatenate the heads along the feature dimension
        h=torch.cat([attention(h,adj) for attention in self.attentions],dim=2)
        h=F.dropout(h,self.dropout,training=self.training)
        h=F.elu(self.out_attention(h,adj))
        return h
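
To show how the batched interface is used, here is a minimal smoke test I am adding as a sketch. All of the sizes (batch_size=4, 10 nodes, 16 input features, 4 attention heads, and so on) and the random data are made up for illustration; the only requirement is that h has shape (batch, node, in_feature) and adj is a (batch, node, node) adjacency matrix. The key difference from the usual single-graph implementations is that W and a are nn.Linear layers shared across the batch, so no per-graph parameter tensors are needed.

import torch

batch_size, num_nodes, in_feature = 4, 10, 16      # illustrative sizes
hidden_feature, out_feature, heads = 8, 3, 4

h = torch.rand(batch_size, num_nodes, in_feature)                        # batched node features
adj = (torch.rand(batch_size, num_nodes, num_nodes) > 0.5).float()       # random 0/1 adjacency matrices

model = GAT(in_feature, hidden_feature, out_feature,
            attention_layers=heads, dropout=0.2, alpha=0.2)
out = model(h, adj)
print(out.shape)   # torch.Size([4, 10, 3])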
