Code:
```python
import torch

a = torch.randn(2, 3)
a1 = a.to_sparse().requires_grad_(True)   # sparse, requires grad
a2 = torch.randn(5, 2, 3).to_sparse()     # batched 3-D sparse tensor (not used below)
a3 = a.to_sparse().requires_grad_(False)  # sparse, no grad
b = torch.randn(3, 2)
b1 = b.to_sparse().requires_grad_(True)   # sparse, requires grad
b2 = torch.randn(5, 3, 2).to_sparse()     # batched 3-D sparse tensor (not used below)
b3 = b.to_sparse().requires_grad_(False)  # sparse, no grad

y1 = torch.sparse.mm(a, b)    # dense × dense
y2 = torch.sparse.mm(a1, b1)  # sparse × sparse
y3 = torch.sparse.mm(a1, b)   # sparse × dense
# y4 = torch.sparse.mm(a, b1)  # dense × sparse -- fails

# z0 = torch.spmm(a1, b)   # sparse (requires grad) × dense -- fails
z1 = torch.spmm(a3, b)     # sparse (no grad) × dense -- OK
z2 = torch.spmm(a, b)      # dense × dense -- OK
# z3 = torch.spmm(a1, b1)  # sparse (requires grad) × sparse (requires grad) -- fails
# z4 = torch.spmm(a, b1)   # dense × sparse (requires grad) -- fails
# z5 = torch.spmm(a, b3)   # dense × sparse (no grad) -- fails
```
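The three `torch.sparse.mm` calls that succeed above can be checked against an ordinary dense matmul. The sketch below assumes a fixed random seed and a small tolerance chosen for float32. It shows that sparse × sparse returns a sparse COO tensor, while sparse × dense returns a regular dense (strided) tensor. Both give the same numbers as `a @ b`.

```python
import torch

torch.manual_seed(0)  # assumption: fixed seed only to make the run repeatable
a = torch.randn(2, 3)
b = torch.randn(3, 2)
expected = a @ b  # plain dense reference result

y2 = torch.sparse.mm(a.to_sparse(), b.to_sparse())  # sparse × sparse
y3 = torch.sparse.mm(a.to_sparse(), b)              # sparse × dense

print(y2.layout)  # layout of the sparse × sparse result
print(y3.layout)  # layout of the sparse × dense result

# Convert the sparse result to dense before comparing values
print(torch.allclose(y2.to_dense(), expected, atol=1e-6))
print(torch.allclose(y3, expected, atol=1e-6))
```

So the layout of the output follows the inputs. If a dense result is needed from sparse × sparse, call `.to_dense()` on it.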
Conclusion (from the runs above):
torch.sparse.mm accepts
- dense $\times$ dense
- sparse $\times$ sparse
- sparse $\times$ dense
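The dense $\times$ sparse case (`y4`) failed in the runs above. One way around it is the identity $AB = (B^\top A^\top)^\top$, which moves the sparse operand into the first position. The helper `dense_sparse_mm` below is hypothetical, not a PyTorch API. This is a sketch that assumes 2-D COO tensors, for which `.t()` is supported.

```python
import torch

def dense_sparse_mm(dense: torch.Tensor, sparse: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: dense (n x k) times sparse COO (k x m) -> dense (n x m).

    Computes (S^T @ D^T)^T so the sparse matrix is the first argument,
    the order that torch.sparse.mm accepted in the runs above.
    """
    return torch.sparse.mm(sparse.t(), dense.t()).t()

torch.manual_seed(0)  # assumption: fixed seed only for repeatability
a = torch.randn(2, 3)
b = torch.randn(3, 2)
b_sparse = b.to_sparse()

out = dense_sparse_mm(a, b_sparse)
print(out.shape)  # torch.Size([2, 2])
print(torch.allclose(out, a @ b, atol=1e-6))
```

The result is a transposed view, so call `.contiguous()` on it if later code needs contiguous memory.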
torch.spmm accepts
- dense $\times$ dense
- sparse $\times$ dense, but only when the sparse matrix does not require grad
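Compared with `torch.spmm`, `torch.sparse.mm` also worked when the sparse input requires grad (`y2`, `y3`). The sketch below runs a backward pass through sparse × dense. The sparse input's gradient comes back as a sparse COO tensor, and the dense input's gradient is an ordinary dense tensor. The seed and the `sum()` loss are assumptions made for illustration.

```python
import torch

torch.manual_seed(0)  # assumption: fixed seed only for repeatability
a = torch.randn(2, 3).to_sparse().requires_grad_(True)  # sparse COO leaf tensor
b = torch.randn(3, 2, requires_grad=True)                # dense leaf tensor

y = torch.sparse.mm(a, b)  # sparse × dense -> dense result
y.sum().backward()         # simple scalar loss so backward() needs no argument

print(a.grad.layout)  # layout of the gradient w.r.t. the sparse input
print(b.grad.layout)  # layout of the gradient w.r.t. the dense input
```

For the loss `y.sum()`, each stored entry `a[i, k]` gets gradient `b[k, :].sum()`, and each `b[k, j]` gets `a[:, k].sum()`.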