params_count(model)：计算模型的参数数量，其中 model 为 nn.Module 类型。
SparseTensor:
from torch_sparse import SparseTensor

# Build a sparse adjacency matrix with row indices taken from edge_index[0]
# and column indices from edge_index[1].
# `value` is optional: pass a tensor of per-edge values, or None.
adj = SparseTensor(row=edge_index[0], col=edge_index[1], value=None,
                   sparse_sizes=(num_nodes, num_nodes))

# Obtain different representations (COO, CSR, CSC):
row, col, value = adj.coo()
rowptr, col, value = adj.csr()
colptr, row, value = adj.csc()

adj = adj[:100, :100]  # Slicing, indexing and masking support
adj = adj.set_diag()   # Add diagonal entries
adj_t = adj.t()        # Transpose
out = adj.matmul(x)    # Sparse-dense matrix multiplication (x needs adj.size(1) rows)
adj = adj.matmul(adj)  # Sparse-sparse matrix multiplication

# Creating SparseTensor instances:
adj = SparseTensor.from_dense(mat)   # here `mat` is a dense torch.Tensor
adj = SparseTensor.eye(100, 100)     # 100 x 100 identity matrix
adj = SparseTensor.from_scipy(mat)   # here `mat` is a scipy.sparse matrix
SparseTensor 支持稀疏-稀疏（sparse-sparse）和稀疏-稠密（sparse-dense）矩阵乘法。
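在消息传递中使用 SparseTensor 时，需要传入邻接矩阵的转置。原因是 edge_index 的第一行是源节点、第二行是目标节点，而基于稀疏矩阵乘法的聚合是按行进行的，转置后目标节点才位于行上。例如，使用 edge_index 时调用 conv(x, edge_index)，使用 SparseTensor 时则调用 conv(x, adj.t())。下面的例子中，T.ToSparseTensor() 已经生成了转置后的 data.adj_t，因此可以直接把 (x, adj_t) 传给 conv。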
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
# With the ToSparseTensor() transform, the loaded graph exposes its
# connectivity as `adj_t` (a SparseTensor), as the printed output shows.
dataset = Planetoid("Planetoid", name="Cora", transform=T.ToSparseTensor())
data = dataset[0]
print(data)
# Data(adj_t=[2708, 2708, nnz=10556], x=[2708, 1433], y=[2708], ...)
class GNN(torch.nn.Module):
    """Two-layer GCN for node classification on the loaded dataset.

    Both convolutions receive the sparse adjacency ``adj_t`` in place of an
    ``edge_index`` tensor.
    """

    def __init__(self):
        super().__init__()
        # cached=True lets GCNConv reuse the normalized adjacency computed on
        # the first call; only appropriate when the graph never changes
        # (transductive setting, as with Cora here).
        self.conv1 = GCNConv(dataset.num_features, 16, cached=True)
        self.conv2 = GCNConv(16, dataset.num_classes, cached=True)

    def forward(self, x, adj_t):
        """Return per-node log-probabilities over ``dataset.num_classes``."""
        x = self.conv1(x, adj_t)
        x = F.relu(x)
        x = self.conv2(x, adj_t)
        return F.log_softmax(x, dim=1)
model = GNN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)


def train(data):
    """Run one full-batch optimization step and return the loss as a float."""
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.adj_t)
    # Restrict the loss to training nodes: computing it over all of data.y
    # would leak validation/test labels into training.
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return float(loss)


for epoch in range(1, 201):
    loss = train(data)
当需要使用 edge_weight 或 edge_attr 时，需要显式地构造 SparseTensor，把它们作为 value 传入，再把转置后的矩阵传给 conv：
from torch_geometric.nn import GMMConv

# GMMConv requires kernel_size (the number of Gaussian kernels) besides dim;
# dim=3 means each edge is expected to carry 3 pseudo-coordinates in edge_attr.
conv = GMMConv(16, 32, dim=3, kernel_size=5)
# Store the edge features as the SparseTensor's values ...
adj = SparseTensor(row=edge_index[0], col=edge_index[1], value=edge_attr)
# ... and pass the transposed matrix to the layer.
out = conv(x, adj.t())
将转置的稀疏矩阵 adj_t 转回 edge_index，可以用下面的方法：
# adj_t is the transposed adjacency, so transpose it back before reading COO.
# edge_attr is None when the matrix was built without values.
row, col, edge_attr = adj_t.t().coo()
edge_index = torch.stack([row, col], dim=0)
参考原文档: