layers.GraphAttentionLayer._prepare_attentional_mechanism_input
Wh.shape
Out[1]: torch.Size([2708, 8])
Wh_repeated_in_chunks.shape
Out[3]: torch.Size([7333264, 8])
N
Out[4]: 2708
N**2
Out[5]: 7333264
Wh_repeated_in_chunks.shape == Wh_repeated_alternating.shape == (N * N, out_features)
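The two repeated tensors are presumably built along these lines (a sketch assuming the pyGAT-style calls, which the transcript itself does not show): repeat_interleave repeats each row N times back to back, while repeat tiles the whole matrix N times.

# Sketch (assumed construction, not shown in the transcript above):
N = Wh.size(0)
Wh_repeated_in_chunks = Wh.repeat_interleave(N, dim=0)  # e1,...,e1, e2,...,e2, ..., eN,...,eN
Wh_repeated_alternating = Wh.repeat(N, 1)               # e1, e2, ..., eN, e1, e2, ..., eN, ...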
The goal is to concatenate the two, producing every pair e_i || e_j:
# e1 || e1
# e1 || e2
# e1 || e3
# ...
# e1 || eN
# e2 || e1
# e2 || e2
# e2 || e3
# ...
# e2 || eN
# ...
# eN || e1
# eN || e2
# eN || e3
# ...
# eN || eN
all_combinations_matrix = torch.cat([Wh_repeated_in_chunks, Wh_repeated_alternating], dim=1)
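A tiny self-contained check (hypothetical values, N=3 and out_features=2; names mirror the ones above) that row i*N + j of the result is indeed e_i || e_j, and that the usual reshape to (N, N, 2*out_features) lines up:

import torch

N, out_features = 3, 2
Wh = torch.arange(N * out_features, dtype=torch.float32).view(N, out_features)
Wh_repeated_in_chunks = Wh.repeat_interleave(N, dim=0)
Wh_repeated_alternating = Wh.repeat(N, 1)
all_combinations_matrix = torch.cat([Wh_repeated_in_chunks, Wh_repeated_alternating], dim=1)
# Reshape so that pairs[i, j] holds e_i || e_j, as in pyGAT's
# .view(N, N, 2 * self.out_features)
pairs = all_combinations_matrix.view(N, N, 2 * out_features)
assert torch.equal(pairs[1, 2], torch.cat([Wh[1], Wh[2]]))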