使用DGL的GATConv层,居然意外的出现如下错误:
dgl._ffi.base.DGLError: Expect number of features to match number of nodes (len(u)). Got 2397 and 799 instead.
注意到799是节点数,而2390刚好是799的3倍,这个3恰好又是num_heads的数值。因此
GATConv的返回值的shape为:
(
N
,
H
,
M
)
(N,H,M)
(N,H,M) ,其中
N
N
N 是节点个数,
H
H
H 是特征长度,而
M
M
M是头的数目。
当不做任何处理,DGL会默认对返回的矩阵做reshape,reshape的目标是(-1,H) 于是矩阵的行数就变成了
N
×
M
N\times M
N×M 了,此时就不对了。
解决方法:对GATConv的返回值执行一次flatten:
def forward(g):
···
for layer in self.layers:
pkt_length_matrix = layer(g,pkt_length_matrix.to(th.device(self.device)))
arv_time_matrix = layer(g,arv_time_matrix.to(th.device(self.device)))
if self.layer_type =='GAT':
pkt_length_matrix = th.flatten(pkt_length_matrix,1)
arv_time_matrix= th.flatten(arv_time_matrix,1)
···
同时,下一层GATConv的in_feat设置为上一层的out_feat
×
\times
× num_heads。
这个就可以了。