How to use attention on a heterogeneous graph


This experiment uses two node types: sentence nodes and word nodes.

For multi-head attention, see the reference here; a minimal multi-head sketch is also included after the layer definition below.

import torch
import dgl
import dgl.function as fn
import torch.nn as nn
import torch.nn.functional as F
def generate_empty_graph(num):
    # Build placeholder heterographs: one node per type, a self-loop for every
    # edge type, and all-zero 768-dim features.
    empty_graphs = []
    for i in range(num):
        g = dgl.heterograph({
            ("sentence", "s_s", "sentence"): (torch.tensor([0]), torch.tensor([0])),
            ("word", "w_w", "word"): (torch.tensor([0]), torch.tensor([0])),
            ("word", "w_s", "sentence"): (torch.tensor([0]), torch.tensor([0])),
            ("sentence", "s_w", "word"): (torch.tensor([0]), torch.tensor([0]))
        })
        g.nodes["word"].data["feats"] = torch.zeros(1, 768)
        g.nodes["sentence"].data["feats"] = torch.zeros(1, 768)
        empty_graphs.append(g)
    return empty_graphs
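
# A quick sanity check (added sketch, not from the original post): each placeholder
# graph should report one node per type and all-zero features.
for _g in generate_empty_graph(2):
    print("placeholder:", _g.number_of_nodes("sentence"), _g.number_of_nodes("word"),
          float(_g.nodes["word"].data["feats"].sum()))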


def build_graph():
    # print("-----------------------------------------------------build graph-------------------------------------------------------")
    print("example:", "I am tiger .")
    g = dgl.heterograph({
        # sentence-to-sentence edge: a single self-loop
        ("sentence", "s_s", "sentence"): (torch.tensor([0]), torch.tensor([0])),
        # word-to-word edges: self-loops plus links between adjacent words in both directions
        ("word", "w_w", "word"): (torch.tensor([0, 1, 2, 3, 0, 1, 2, 1, 2, 3]), torch.tensor([0, 1, 2, 3, 1, 2, 3, 0, 1, 2])),
        # every word connects to its sentence, and the sentence back to every word
        ("word", "w_s", "sentence"): (torch.tensor([0, 1, 2, 3]), torch.tensor([0, 0, 0, 0])),
        ("sentence", "s_w", "word"): (torch.tensor([0, 0, 0, 0]), torch.tensor([0, 1, 2, 3]))
    })
    g.nodes["sentence"].data["feats"] = torch.ones(1, 768)
    g.nodes["word"].data["feats"] = torch.ones(4, 768)
    # g = dgl.remove_self_loop(g, etype="self_loop")
    # g = dgl.add_self_loop(g, etype="self_loop")  # here you must decide which edge type to add:
    # a self-loop also needs an edge type, and it must be one of the existing edge types
    ## A heterograph can only be directed: srctype->etype->dsttype and dsttype->etype->srctype are different relations
    # print(dgl.to_bidirected(g))
    # print("canonical_etypes-->", g.canonical_etypes)
    return g
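
# Sanity check (added sketch, not from the original post): the toy sentence
# "I am tiger ." gives 1 sentence node, 4 word nodes and four edge types.
_g = build_graph()
print("nodes:", _g.number_of_nodes("sentence"), _g.number_of_nodes("word"))
print("etypes:", _g.canonical_etypes)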


class HeteroRGCNLayer(nn.Module):
    def __init__(self, in_dim, out_dim, etypes):
        super(HeteroRGCNLayer, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.etypes = etypes
        self.id_to_etype = {i: name for i, name in enumerate(etypes)}

        # equation 1: a separate linear projection W_r for each edge type (4 types here)
        self.fc = nn.ModuleDict({
                name: nn.Linear(in_dim, out_dim, bias=False) for name in etypes
            })
        # equation 2: a separate attention vector a_r for each edge type (4 types here)
        self.attn_fc = nn.ModuleDict({
                name: nn.Linear(2 * out_dim, 1, bias=False) for name in etypes
            })
    def edge_attention(self, edges):
        # GAT-style score per edge type: e_ij = LeakyReLU(a_r^T [W_r h_i || W_r h_j])
        etype = self.id_to_etype[int(edges.data['id'][0])]
        print("edge type in attention computation:", etype)
        z2 = torch.cat([edges.src['wsrc%s' % etype], edges.dst['wdst%s' % etype]], dim=1)  # z_i || z_j
        a = self.attn_fc[etype](z2)
        return {'alpha': F.leaky_relu(a)}

    def message_func(self, edges):
        etype = self.id_to_etype[int(edges.data['id'][0])]
        print("edge type in message func:", etype)
        return {'v': edges.src['wsrc%s' % etype], 'alpha': edges.data['alpha']}

    def reduce_func(self, nodes):
        # normalize the raw scores over each node's mailbox, then take the weighted sum
        alpha = F.softmax(nodes.mailbox['alpha'], dim=1)
        h = torch.sum(alpha * nodes.mailbox['v'], dim=1)
        print("reduce executed...")
        return {'h': h}


    def forward(self, G):
        # Tag every edge with an integer id, only so that edge UDFs can tell which
        # edge type they are handling and hence which fc/attn_fc to use.
        G.edge_dict = {}
        for etype in G.etypes:
            G.edge_dict[etype] = len(G.edge_dict)
            G.edges[etype].data['id'] = torch.ones(G.number_of_edges(etype), dtype=torch.long) * G.edge_dict[etype]

        # gather the input features per node type
        feat_dict = {}
        feat_dict["sentence"] = G.nodes["sentence"].data["feats"]
        feat_dict["word"] = G.nodes["word"].data["feats"]

        # compute the attention coefficients for every edge type
        # funcs = {}
        for srctype, etype, dsttype in G.canonical_etypes:

            print("loop computing:", srctype, "---", etype, "---", dsttype)
            Wh = self.fc[etype](feat_dict[srctype])
            G.nodes[srctype].data['wsrc%s' % etype] = Wh
            Wh = self.fc[etype](feat_dict[dsttype])
            G.nodes[dsttype].data['wdst%s' % etype] = Wh
            G.apply_edges(func=self.edge_attention, etype=etype)
        G.multi_update_all({etype: (self.message_func, self.reduce_func) for etype in G.edge_dict.keys()}, cross_reducer='mean')
        return {ntype: G.nodes[ntype].data['h'] for ntype in G.ntypes}
        #     # compute W_r * h
        #     Wh = self.weight[etype](feat_dict[srctype])  # linear projection
        #     # Wh = G.nodes[srctype].data["feats"]  # for testing only
        #     # store it in the graph for message passing
        #     G.nodes[srctype].data['Wh_%s' % etype] = Wh
        #     # specify a (message_func, reduce_func) pair per relation.
        #     # Note the results are all saved to the same destination feature 'h',
        #     # which implies a type-wise reduction on top.
        #     funcs[(srctype, etype, dsttype)] = (fn.copy_u('Wh_%s' % etype, 'm'), fn.sum('m', 'h'))
        # # The first argument is the per-relation message passing functions;
        # # the second is the type-wise reduce function: 'sum', 'max', 'min', 'mean' or 'stack'
        # G.multi_update_all(funcs, 'sum')
        # # return the updated node features as a dictionary
        # return {ntype : G.nodes[ntype].data['h'] for ntype in G.ntypes}




hrgcn=HeteroRGCNLayer(in_dim=768,out_dim=256,etypes=["s_s","s_w","w_s","w_w"])
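
# The multi-head reference above is only linked, so here is a minimal multi-head
# sketch (my addition, not the referenced implementation): run several independent
# HeteroRGCNLayer heads and concatenate their per-node-type outputs, as in GAT.
# The class name and the num_heads parameter are hypothetical.
class MultiHeadHeteroLayer(nn.Module):
    def __init__(self, in_dim, out_dim, etypes, num_heads=2):
        super(MultiHeadHeteroLayer, self).__init__()
        self.heads = nn.ModuleList([HeteroRGCNLayer(in_dim, out_dim, etypes) for _ in range(num_heads)])

    def forward(self, G):
        head_outs = [head(G) for head in self.heads]
        # concatenate each node type's features across heads: (N, num_heads * out_dim)
        return {ntype: torch.cat([out[ntype] for out in head_outs], dim=1) for ntype in head_outs[0]}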



def merge_hetero_graph(heterograph_list, add_new_edge=True):
    bg = dgl.batch(heterograph_list)
    if not add_new_edge:
        # rebuild the batched graph as a single heterograph, keeping the edges as-is
        data_dict = {}
        for srctype, etype, dsttype in bg.canonical_etypes:
            data_dict[(srctype, etype, dsttype)] = bg.edges(etype=etype)
        new_g = dgl.heterograph(data_dict)
        new_g.nodes["sentence"].data["feats"] = bg.nodes["sentence"].data["feats"]
        new_g.nodes["word"].data["feats"] = bg.nodes["word"].data["feats"]

    else:
        data_dict = {}
        for srctype, etype, dsttype in bg.canonical_etypes:
            data_dict[(srctype, etype, dsttype)] = bg.edges(etype=etype)
        # chain the sentence nodes of the batched graph in both directions,
        # e.g. for sentence nodes [0,1,2,3,4]:
        node_list = bg.nodes("sentence").numpy().tolist()
        # src: [0,1,2,3] + [1,2,3,4]
        s_s_src = node_list[:-1] + node_list[1:]
        # dst: [1,2,3,4] + [0,1,2,3]
        s_s_dst = node_list[1:] + node_list[:-1]
        # then append the original s_s edges (the per-graph self-loops)
        (src, dst) = data_dict[("sentence", "s_s", "sentence")]
        s_s_src = s_s_src + src.numpy().tolist()
        s_s_dst = s_s_dst + dst.numpy().tolist()
        data_dict[("sentence", "s_s", "sentence")] = (torch.tensor(s_s_src), torch.tensor(s_s_dst))
        new_g = dgl.heterograph(data_dict)
        new_g.nodes["sentence"].data["feats"] = bg.nodes["sentence"].data["feats"]
        new_g.nodes["word"].data["feats"] = bg.nodes["word"].data["feats"]

    return new_g

g1 = build_graph()
g2 = build_graph()
bg = merge_hetero_graph([g1, g1, g1])
print(bg)
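
# Quick check (added sketch): merging three copies gives the 3 original s_s
# self-loops plus 2 * (3 - 1) = 4 new chain edges, i.e. 7 s_s edges in total.
print("s_s edges after merge:", bg.number_of_edges("s_s"))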

for srctype, etype, dsttype in g1.canonical_etypes:
    print(etype,":",g1.edges(etype=etype))
res=hrgcn(g1)
print("res====",res.keys())
'''
return {ntype : G.nodes[ntype].data['h'] for ntype in G.ntypes}
bg.edges(etype=etype)
print("node type",G.ntypes) # ['sentence', 'word']
'''


