Converting edge_index to adj_t
1. Converting a standalone edge_index to adj_t: torch_sparse.SparseTensor
import torch
from torch_sparse import SparseTensor
edge_index = torch.tensor([[1,2,3,4], [0,0,0,0]])
adj_t = SparseTensor(row=edge_index[1], col=edge_index[0])
adj_t
>>> SparseTensor(row=tensor([0, 0, 0, 0]),
col=tensor([1, 2, 3, 4]),
size=(1, 5), nnz=4, density=80.00%)
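To see what this SparseTensor encodes, the pure-PyTorch sketch below (it does not use torch_sparse) scatters the same row/col pairs into a dense matrix and recomputes nnz and density. It assumes, consistent with the size=(1, 5) printed above, that the shape is inferred as (max(row) + 1, max(col) + 1) when no size is passed.

```python
import torch

# Same edge_index as above: edges 1->0, 2->0, 3->0, 4->0
edge_index = torch.tensor([[1, 2, 3, 4], [0, 0, 0, 0]])

# adj_t was built with row=edge_index[1] (targets) and col=edge_index[0] (sources)
row, col = edge_index[1], edge_index[0]

# Assumed inferred shape: one past the largest index in each dimension
num_rows = int(row.max()) + 1
num_cols = int(col.max()) + 1

# Put a 1 at every (row, col) position to get the equivalent dense matrix
dense = torch.zeros(num_rows, num_cols)
dense[row, col] = 1.0

nnz = int((dense != 0).sum())
density = nnz / (num_rows * num_cols)

print(dense)
print(dense.shape, nnz, f"{density:.2%}")
```

The dense view is a single row with ones in columns 1 to 4, which is where the 80% density (4 of 5 entries) comes from.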
By default, SparseTensor sorts its entries by the index in row (entries with the same row are ordered by col).
edge_index = torch.tensor([[0,0,0,0], [4,3,1,2]])
adj_t = SparseTensor(row=edge_index[1], col=edge_index[0])
adj_t
>>> SparseTensor(row=tensor([1, 2, 3, 4]),
col=tensor([0, 0, 0, 0]),
size=(5, 1), nnz=4, density=80.00%)
# the row input [4, 3, 1, 2] has been reordered to [1, 2, 3, 4]
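The reordering can be reproduced with a row-major sort key. This is a sketch of the observable behavior, not torch_sparse's own code; the value tensor is a made-up per-edge weight, added only to show that any per-edge data has to follow the same permutation.

```python
import torch

# Second example: row = edge_index[1], col = edge_index[0]
edge_index = torch.tensor([[0, 0, 0, 0], [4, 3, 1, 2]])
row, col = edge_index[1], edge_index[0]
value = torch.tensor([40.0, 30.0, 10.0, 20.0])  # hypothetical per-edge weights

# Row-major key: order by row first, then by column within the same row
num_cols = int(col.max()) + 1
perm = (row * num_cols + col).argsort()

# Apply the same permutation to every per-edge tensor
row, col, value = row[perm], col[perm], value[perm]
print(row)    # tensor([1, 2, 3, 4])
print(col)    # tensor([0, 0, 0, 0])
print(value)  # tensor([10., 20., 30., 40.])
```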
2. Converting the edge_index of a Data object to adj_t: torch_geometric.transforms.ToSparseTensor
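import torch
from torch_geometric.data import Data
from torch_geometric.transforms import ToSparseTensor
# same edge_index as in the first example of section 1 (the output below matches it)
edge_index = torch.tensor([[1,2,3,4], [0,0,0,0]])
transform = ToSparseTensor()
x = torch.randn(5,3)
data = Data(x=x, edge_index=edge_index)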
data
>>> Data(edge_index=[2, 4], x=[5, 3])
data = transform(data)  # by default, edge_index is replaced by adj_t
data
>>> Data(adj_t=[5, 5, nnz=4], x=[5, 3])
data.adj_t
>>> SparseTensor(row=tensor([0, 0, 0, 0]),
col=tensor([1, 2, 3, 4]),
size=(5, 5), nnz=4, density=16.00%)
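Conceptually, the transform swaps the two rows of edge_index (so the rows of adj_t are target nodes) and fixes the size to (num_nodes, num_nodes), which is why the size here is (5, 5) rather than the inferred (1, 5) from section 1. A simplified sketch, ignoring edge attributes and the transform's other options; edge_index_to_adj_t is a hypothetical helper, not a PyG function:

```python
import torch


def edge_index_to_adj_t(edge_index, num_nodes):
    """Hypothetical helper: returns (row, col, size) of the transposed adjacency."""
    src, dst = edge_index[0], edge_index[1]
    # Transpose: rows are targets, columns are sources
    row, col = dst, src
    # Keep entries in row-major order, as SparseTensor stores them
    perm = (row * num_nodes + col).argsort()
    return row[perm], col[perm], (num_nodes, num_nodes)


edge_index = torch.tensor([[1, 2, 3, 4], [0, 0, 0, 0]])
# num_nodes plays the role of data.num_nodes, i.e. x.size(0) == 5
row, col, size = edge_index_to_adj_t(edge_index, num_nodes=5)
density = row.numel() / (size[0] * size[1])
print(row, col, size, f"{density:.2%}")
```

With 4 edges in a 5 x 5 matrix the density is 4 / 25, matching the 16.00% shown above.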
For why PyTorch Geometric stores the sparse adjacency matrix in transposed form, see my other blog post: https://blog.csdn.net/weixin_39925939/article/details/121331550
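As a short dense illustration only (plain torch, sum aggregation assumed), multiplying the transposed adjacency by the node features gathers, for each target node, the features of its source nodes. This matches a scatter-sum over edge_index in PyG's default source-to-target direction:

```python
import torch

# Edges 1->0, 2->0, 3->0, 4->0: node 0 receives messages from nodes 1 to 4
edge_index = torch.tensor([[1, 2, 3, 4], [0, 0, 0, 0]])
num_nodes = 5
x = torch.arange(num_nodes, dtype=torch.float).unsqueeze(1)  # node i has feature i

# Dense adj_t: entry [target, source] = 1
adj_t = torch.zeros(num_nodes, num_nodes)
adj_t[edge_index[1], edge_index[0]] = 1.0

# Row i of adj_t @ x sums the features of all sources pointing at node i
out = adj_t @ x

# Same result as scatter-summing source features into their targets
out2 = torch.zeros(num_nodes, 1).index_add_(0, edge_index[1], x[edge_index[0]])

print(out.squeeze(1))  # node 0 gets 1 + 2 + 3 + 4 = 10, the others get 0
```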
Converting adj_t back to edge_index
# data is a torch_geometric.data.Data object holding adj_t; .t() undoes the transpose
row, col, _ = data.adj_t.t().coo()  # coo() returns (row, col, value)
edge_index = torch.stack([row, col], dim=0)
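The round trip preserves the edges but not necessarily their order, since SparseTensor keeps its entries sorted. The pure-PyTorch sketch below redoes the round trip by hand for the second example of section 1 and compares the result as a set of (source, target) pairs; same_edges is a hypothetical helper, not a PyG function.

```python
import torch


def same_edges(a, b):
    """Hypothetical helper: compare two edge_index tensors as sets of edges."""
    return set(map(tuple, a.t().tolist())) == set(map(tuple, b.t().tolist()))


# Second example from section 1: the edge order is not row-major
edge_index = torch.tensor([[0, 0, 0, 0], [4, 3, 1, 2]])
num_nodes = 5

# adj_t coordinates: rows are targets, cols are sources, sorted row-major
perm = (edge_index[1] * num_nodes + edge_index[0]).argsort()
adj_row, adj_col = edge_index[1][perm], edge_index[0][perm]

# Transposing back puts sources on top again
recovered = torch.stack([adj_col, adj_row], dim=0)

print(recovered)                          # rows [0, 0, 0, 0] and [1, 2, 3, 4]
print(torch.equal(recovered, edge_index))  # False: the order changed
print(same_edges(recovered, edge_index))   # True: the edges are the same
```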