DGL 的GATConv报错:Expect number of features to match number of nodes (len(u)). Got 2397 and 799 instead

使用DGL的GATConv层,居然意外的出现如下错误:

dgl._ffi.base.DGLError: Expect number of features to match number of nodes (len(u)). Got 2397 and 799 instead.

注意到799是节点数,而2390刚好是799的3倍,这个3恰好又是num_heads的数值。因此
GATConv的返回值的shape为: ( N , H , M ) (N,H,M) (N,H,M) ,其中 N N N 是节点个数, H H H 是特征长度,而 M M M是头的数目。
当不做任何处理,DGL会默认对返回的矩阵做reshape,reshape的目标是(-1,H) 于是矩阵的行数就变成了 N × M N\times M N×M 了,此时就不对了。

解决方法:对GATConv的返回值执行一次flatten:

def forward(g):
···
        for layer in self.layers:
            pkt_length_matrix = layer(g,pkt_length_matrix.to(th.device(self.device)))
            arv_time_matrix = layer(g,arv_time_matrix.to(th.device(self.device)))
            if self.layer_type =='GAT':
                pkt_length_matrix = th.flatten(pkt_length_matrix,1)
                arv_time_matrix= th.flatten(arv_time_matrix,1)
···

同时,下一层GATConv的in_feat设置为上一层的out_feat × \times × num_heads。
这个就可以了。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值